- Implementing and evaluating a multi-layer perceptron (MLP) and convolutional neural network (CNN) in solving a classification problem
- Building, evaluating, and finetuning a CNN on an image dataset from development to testing
- Tackling overfitting using strategies such as data augmentation and drop out
- Fine tuning a model
- Comparing the performance of a new model with an off-the-shelf model (AlexNet)
- Gaining a deeper understanding of model performance using visualisations from Grad-CAM.
Having a GPU will speed up the training process. See the provided document on Minerva about setting up a working environment for various ways to access a GPU. We highly recommend you use platforms such as Colab.
Please implement the coursework using Python and PyTorch, and refer to the notebooks and exercises provided.
This coursework will use a subset of images from Tiny ImageNet, which is a subset of the ImageNet dataset. Our subset of Tiny ImageNet contains 30 different categories, we will refer to it as TinyImageNet30. The training set has 450 resized images (64x64 pixels) for each category (13,500 images in total). You can download the training and test set from a direct link or the Kaggle challenge website:
Direct access to data is possible by clicking here, please use your university email to access this
import math
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.hub import load_state_dict_from_url
from PIL import Image
import matplotlib.pyplot as plt
import pathlib
import os
import glob
from torch.utils.data import Dataset,DataLoader
import cv2 as cv
import time
from torchcam.methods import SmoothGradCAMpp
import torchvision
from sklearn import metrics
import pandas as pd
from torch import optim
from torch.optim import SGD, Adam
from collections import OrderedDict
from scipy import interp
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
from itertools import cycle
import torchvision.models as models
# always check your version
print(torch.__version__)
1.13.0
One challenge of building a deep learning model is to choose an architecture that can learn the features in the dataset without being unnecessarily complex. The first part of the coursework involves building a CNN and training it on TinyImageNet30.
1. Function implementation
Dataset and DataLoader classes Model class for a simple MLP model Model class for a simple CNN model 2. Model training
3. Model Fine-tuning on CIFAR10 dataset
4. Model testing
5. Model comparison
6. Interpretation of results
1. Function implementation
Dataset and DataLoader classesModel class for a simple MLP model Model class for a simple CNN model 2. Model training
3. Model Fine-tuning on CIFAR10 dataset
4. Model testing
5. Model comparison
6. Interpretation of results
# TO COMPLETE
classes = {}
with open("./comp5625M_data_assessment_1/class.txt", "r") as f: # open class file to get all classes and labels
data = f.read().splitlines()
for item in data:
item_class = item.strip("\t").split("\t")
label, category = item_class[0], item_class[1]
classes[category] = label
classes
{'baboon': '0',
'banana': '1',
'bee': '2',
'bison': '3',
'butterfly': '4',
'candle': '5',
'cardigan': '6',
'chihuahua': '7',
'elephant': '8',
'espresso': '9',
'fly': '10',
'goldfish': '11',
'goose': '12',
'grasshopper': '13',
'hourglass': '14',
'icecream': '15',
'ipod': '16',
'jellyfish': '17',
'koala': '18',
'ladybug': '19',
'lion': '20',
'mushroom': '21',
'penguin': '22',
'pig': '23',
'pizza': '24',
'pretzel': '25',
'redpanda': '26',
'refrigerator': '27',
'sombrero': '28',
'umbrella': '29'}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
os.environ['CUDA_VISIBLE_DEVICES'] ='0'
Using device: cuda
train_transformer = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225]),
])
class MyDataset(Dataset):
def __init__(self, data_type, transform=train_transformer):
'''
data_type : ["train_set", "test_set"]
'''
# this is the dictionary path of dataset
root_path = "./comp5625M_data_assessment_1/"
# get type of dataset
self.data_type = data_type
# join the file path above the categories dictionary of the pictures
data_root = pathlib.Path(root_path+self.data_type+"/"+self.data_type)
# obtain all files and all subfiles absolute paths in this dictionary
if self.data_type == "train_set":
all_image_paths = list(data_root.glob("*/*"))
self.all_image_paths = all_image_paths
# find the label from the global variable: classes, based on each path class name
# the function path.parent.name can provide the father dictionary's name, which is the class name
self.all_image_labels = [int(classes[path.parent.name]) for path in all_image_paths]
self.all_image_paths = [str(path) for path in all_image_paths]
self.transform = transform
else:
all_image_paths = list(data_root.glob("*/"))
self.all_image_paths = [str(path) for path in all_image_paths]
# for save result of test csv file
self.all_image_labels = [str(path) for path in all_image_paths]
self.transform = transform
def __getitem__(self, index):
img = cv.imread(self.all_image_paths[index])
img=self.transform(img)
label = self.all_image_labels[index]
return img, label
def __len__(self):
return len(self.all_image_paths)
Create a new model class using a combination of:
# TO COMPLETE
# define a MLP Model class
class MLP_Class(nn.Module):
def __init__(self):
super(MLP_Class,self).__init__()
self.layer = nn.Sequential(
OrderedDict(
[
("flatten", nn.Flatten()),
("hidden_1_layer", nn.Linear(3*64*64,1024)),
('relu1', nn.ReLU()),
("hidden_2_layer", nn.Linear(1024,512)),
('relu2', nn.ReLU()),
("hidden_3_layer", nn.Linear(512, 30)),
]
))
def forward(self,x):
x = self.layer(x)
return x
MLP_model = MLP_Class()
MLP_model = MLP_model.to(device)
print(MLP_model)
MLP_Class(
(layer): Sequential(
(flatten): Flatten(start_dim=1, end_dim=-1)
(hidden_1_layer): Linear(in_features=12288, out_features=1024, bias=True)
(relu1): ReLU()
(hidden_2_layer): Linear(in_features=1024, out_features=512, bias=True)
(relu2): ReLU()
(hidden_3_layer): Linear(in_features=512, out_features=30, bias=True)
)
)
Create a new model class using a combination of:
class CNN_Class(nn.Module):
def __init__(self):
super(CNN_Class,self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flc1 = nn.Linear(64*8*8,1024)
self.flc2 = nn.Linear(1024,30)
def forward(self,x):
x = self.maxpool1(nn.functional.relu(self.conv1(x)))
x = self.maxpool2(nn.functional.relu(self.conv2(x)))
x = self.maxpool3(nn.functional.relu(self.conv3(x)))
x = x.view(-1,64*8*8)
x = nn.functional.relu(self.flc1(x))
x=self.flc2(x)
return x
CNN_model = CNN_Class()
CNN_model = CNN_model.to(device)
print(CNN_model)
CNN_Class( (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (flc1): Linear(in_features=4096, out_features=1024, bias=True) (flc2): Linear(in_features=1024, out_features=30, bias=True) (softmax): Softmax(dim=1) )
Train your model on the TinyImageNet30 dataset. Split the data into train and validation sets to determine when to stop training. Use seed at 0 for reproducibility and test_ratio=0.2 (validation data)
Display the graph of training and validation loss over epochs and accuracy over epochs to show how you determined the optimal number of training epochs. Top-k accuracy implementation is provided for you below.
Please leave the graph clearly displayed. Please use the same graph to plot graphs for both train and validation.
# split train dataset to train_set and validate_set based on test_ratio=0.2
train_set = MyDataset("train_set")
length=len(train_set)
train_size,validate_size=int(0.8*length),int(0.2*length)
train_set,validate_set=torch.utils.data.random_split(train_set,[train_size,validate_size],generator=torch.Generator().manual_seed(0))
print(len(train_set),len(validate_set))
train_loader = DataLoader(
train_set,
batch_size = 64,
shuffle = True)
validate_loader = DataLoader(
validate_set,
batch_size = 64,
shuffle = True)
10800 2700
def train(train_set, model, criterion, optimizer):
model.train()
n = 0
train_running_loss = 0.0
train_running_accuracy = 0.0
for data in train_set:
images, labels = data
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_running_loss += loss.item()
train_running_accuracy += topk_accuracy(output = outputs, target = labels, topk=(1,))[0].cpu().float()
n += 1
return train_running_loss / n, (train_running_accuracy / n).cpu().numpy()
def validate(val_set, model, criterion, optimizer):
model.eval()
n = 0
validate_running_loss = 0.0
validate_running_accuracy = 0.0
with torch.no_grad():
for data in val_set:
images, labels = data
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
validate_running_loss += loss.item()
validate_running_accuracy += topk_accuracy(output = outputs, target = labels, topk=(1,))[0]
n += 1
return validate_running_loss / n, (validate_running_accuracy / n).cpu().numpy()
# Define top-*k* accuracy
def topk_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
#TO COMPLETE --> Running you MLP model class
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(MLP_model.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0
train_loss, validate_loss, train_accuracy, validate_accuracy = [], [], [], []
for epoch in range(nepochs):
train_running_loss , train_running_accuracy = train(train_loader, MLP_model, criterion, optimizer)
train_loss.append(train_running_loss)
train_accuracy.append(train_running_accuracy)
validate_running_loss , validate_running_accuracy = validate(validate_loader, MLP_model, criterion, optimizer)
validate_loss.append(validate_running_loss)
validate_accuracy.append(validate_running_accuracy)
scheduler.step()
if validate_running_loss < best_loss:
best_loss = validate_running_loss
torch.save(MLP_model.state_dict(), './MLP_model.pt')
# Your graph
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets')
axs[0].plot(x_axis,train_loss,'r--',label='MLP_train_loss')
axs[0].plot(x_axis,validate_loss,'g--',label='MLP_validate_loss')
axs[1].plot(x_axis,train_accuracy,'b--',label='MLP_train_accuracy')
axs[1].plot(x_axis,validate_accuracy,'y--',label='MLP_validate_accuracy')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()
#TO COMPLETE --> Running you CNN model class
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_model.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0
CNN_train_loss, CNN_validate_loss, CNN_train_accuracy, CNN_validate_accuracy = [], [], [], []
for epoch in range(nepochs):
train_running_loss , train_running_accuracy = train(train_loader, CNN_model, criterion, optimizer)
CNN_train_loss.append(train_running_loss)
CNN_train_accuracy.append(train_running_accuracy)
validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model, criterion, optimizer)
CNN_validate_loss.append(validate_running_loss)
CNN_validate_accuracy.append(validate_running_accuracy)
scheduler.step()
if validate_running_loss < best_loss:
best_loss = validate_running_loss
torch.save(CNN_model.state_dict(), './CNN_model.pt')
# Your graph
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets')
axs[0].plot(x_axis,CNN_train_loss,'r--',label='CNN_train_loss')
axs[0].plot(x_axis,CNN_validate_loss,'g--',label='CNN_validate_loss')
axs[1].plot(x_axis,CNN_train_accuracy,'b--',label='CNN_train_accuracy')
axs[1].plot(x_axis,CNN_validate_accuracy,'y--',label='CNN_validate_accuracy')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()
Comment on your model and results that should include number of parameters in each model and why CNN over MLP for image classification task?
# Your code here!
def accuracy(cnfm):
return cnfm.trace()/cnfm.sum((0,1))
def recalls(cnfm):
return np.diag(cnfm)/cnfm.sum(1)
def precisions(cnfm):
return np.diag(cnfm)/cnfm.sum(0)
num_class = len(classes)
CNN_model.load_state_dict(torch.load('./CNN_model.pt'))
nclasses = len(classes)
cnfm = np.zeros((nclasses,nclasses),dtype=int)
score_list = [] # save predicted score
label_list = []
with torch.no_grad():
for data in validate_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
outputs = CNN_model(images)
_, predicted = torch.max(outputs, 1)
score_tmp = outputs
for i in range(labels.size(0)):
cnfm[labels[i].item(),predicted[i].item()] += 1
score_list.extend(score_tmp.detach().cpu().numpy())
label_list.extend(labels.cpu().numpy())
print("Confusion matrix")
print(cnfm)
# show confusion matrix as a grey-level image
plt.imshow(cnfm, cmap='gray')
# show per-class recall and precision
print(f"Accuracy: {accuracy(cnfm) :.1%}")
r = recalls(cnfm)
p = precisions(cnfm)
for i in range(nclasses):
print(f"Class {list(classes.keys())[i]} : Precision {p[i] :.1%} Recall {r[i] :.1%}")
Confusion matrix [[30 0 1 4 0 1 3 3 7 1 1 0 3 0 3 1 1 1 7 2 5 1 1 11 0 0 3 0 0 0] [ 1 29 8 1 0 4 0 0 0 1 1 3 0 0 3 0 2 0 1 2 0 1 3 0 5 2 0 0 3 1] [ 0 1 42 2 8 0 0 1 0 1 10 1 0 7 1 0 1 0 0 9 0 1 0 1 0 1 1 0 0 3] [ 1 1 0 54 0 0 2 1 9 1 0 0 2 1 0 0 2 0 2 0 0 1 2 9 0 0 3 0 1 0] [ 0 0 6 0 79 2 0 1 0 0 2 1 0 4 0 0 0 0 1 1 0 1 0 0 0 1 0 0 2 3] [ 0 7 5 0 0 31 1 7 0 1 1 3 0 0 8 0 3 3 0 4 0 1 2 1 0 1 2 0 4 4] [ 0 2 4 1 1 1 40 2 2 0 0 2 0 4 3 0 4 0 4 0 1 1 5 1 0 0 2 2 3 7] [ 2 0 2 1 0 3 0 32 1 2 0 1 3 0 2 4 5 0 4 3 6 1 0 5 0 2 5 3 2 2] [ 2 0 0 14 0 0 2 0 44 1 0 0 1 1 0 0 3 0 4 1 2 1 1 6 0 0 2 1 2 1] [ 0 2 2 0 0 2 0 3 0 43 1 4 0 0 7 1 7 0 0 2 0 1 1 0 3 3 1 1 3 0] [ 1 0 14 1 4 0 1 1 0 0 45 0 1 11 2 1 0 0 0 10 0 1 0 0 1 0 0 0 2 1] [ 0 3 1 0 0 2 0 1 2 2 0 54 0 3 0 0 0 3 0 6 3 4 1 1 0 0 2 0 2 2] [ 1 0 2 1 1 1 2 3 3 1 2 0 32 1 2 1 4 3 1 1 4 2 4 9 0 0 2 0 1 4] [ 0 3 15 2 3 1 1 1 2 0 7 1 2 46 0 0 1 0 2 4 1 0 1 0 0 0 0 1 0 0] [ 0 3 2 0 0 11 2 4 0 0 2 0 2 2 41 0 5 1 2 1 0 0 4 0 1 0 0 2 1 4] [ 0 3 0 0 1 8 1 4 1 8 3 4 1 1 4 12 3 1 0 5 2 3 0 4 4 8 0 1 9 3] [ 1 2 0 2 0 1 1 3 0 2 1 0 3 1 4 2 34 0 0 0 1 1 2 1 0 1 0 3 4 2] [ 0 1 0 0 0 2 0 0 0 0 1 2 0 1 2 0 0 68 0 3 0 0 0 0 0 1 0 0 0 2] [ 3 0 1 3 1 0 1 2 5 0 1 0 2 0 0 0 0 0 62 0 2 2 0 3 0 1 2 0 0 2] [ 1 6 17 2 1 0 0 3 0 1 4 1 1 12 1 1 2 0 0 40 1 1 1 0 0 0 1 1 0 0] [ 4 1 3 0 0 1 1 6 5 0 0 0 1 2 2 0 0 0 4 1 38 5 0 10 0 4 1 2 2 2] [ 0 1 3 2 4 4 0 2 3 2 1 3 1 3 0 0 1 1 1 1 1 35 0 2 1 0 3 0 5 2] [ 0 1 2 3 1 1 3 4 5 0 0 0 7 0 4 1 0 3 0 0 0 0 35 1 0 0 0 3 3 2] [ 5 0 0 6 0 0 0 4 8 0 1 1 6 2 1 0 0 1 3 0 2 5 2 33 4 1 2 3 1 2] [ 0 3 0 0 1 5 0 1 0 0 0 1 0 3 0 2 0 0 0 0 0 2 0 1 55 6 0 1 4 2] [ 0 3 2 2 1 5 1 3 0 6 1 2 0 0 2 2 1 0 1 2 3 0 0 0 15 26 1 2 5 0] [ 6 0 3 4 0 0 0 3 3 0 1 0 0 0 1 0 0 0 5 3 1 9 0 0 0 1 52 0 0 2] [ 1 3 0 0 0 4 3 1 2 2 0 1 2 0 12 1 8 0 2 0 2 2 3 2 1 2 1 28 2 6] [ 1 1 1 4 0 4 5 7 3 0 0 1 1 3 3 2 3 1 0 1 2 4 0 5 0 3 2 1 36 6] [ 0 7 2 1 2 0 3 2 2 0 2 1 4 5 3 3 6 3 2 4 1 5 2 3 1 0 1 1 8 22]] Accuracy: 45.1% Class baboon : Precision 50.0% Recall 33.3% Class banana : Precision 34.9% Recall 40.8% Class bee : Precision 30.4% Recall 46.2% Class bison : Precision 49.1% Recall 58.7% Class butterfly : Precision 73.1% Recall 76.0% Class candle : Precision 33.0% Recall 34.8% Class cardigan : Precision 54.8% Recall 43.5% Class chihuahua : Precision 30.5% Recall 35.2% Class elephant : Precision 41.1% Recall 49.4% Class espresso : Precision 57.3% Recall 49.4% Class fly : Precision 51.1% Recall 46.4% Class goldfish : Precision 62.1% Recall 58.7% Class goose : Precision 42.7% Recall 36.4% Class grasshopper : Precision 40.7% Recall 48.9% Class hourglass : Precision 36.9% Recall 45.6% Class icecream : Precision 35.3% Recall 12.8% Class ipod : Precision 35.4% Recall 47.2% Class jellyfish : Precision 76.4% Recall 81.9% Class koala : Precision 57.4% Recall 66.7% Class ladybug : Precision 37.7% Recall 40.8% Class lion : Precision 48.7% Recall 40.0% Class mushroom : Precision 38.5% Recall 42.7% Class penguin : Precision 50.0% Recall 44.3% Class pig : Precision 30.3% Recall 35.5% Class pizza : Precision 60.4% Recall 63.2% Class pretzel : Precision 40.6% Recall 30.2% Class redpanda : Precision 58.4% Recall 55.3% Class refrigerator : Precision 50.0% Recall 30.8% Class sombrero : Precision 34.3% Recall 36.0% Class umbrella : Precision 25.3% Recall 22.9%
score_array = np.array(score_list)
# make label convert to be onehot form
label_tensor = torch.tensor(label_list)
label_tensor = label_tensor.reshape((label_tensor.shape[0], 1))
label_onehot = torch.zeros(label_tensor.shape[0], num_class)
label_onehot.scatter_(dim=1, index=label_tensor, value=1)
label_onehot = np.array(label_onehot)
# call sklearn to calculate the corresponding fpr and tpr of each class
fpr_dict = dict()
tpr_dict = dict()
roc_auc_dict = dict()
for i in range(num_class):
fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score_array[:, i])
roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])
# micro
fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score_array.ravel())
roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])
# macro
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)]))
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(num_class):
mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i])
# Finally average it and compute AUC
mean_tpr /= num_class
fpr_dict["macro"] = all_fpr
tpr_dict["macro"] = mean_tpr
roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])
roc_auc_dict_order=sorted(roc_auc_dict.items(),key=lambda x:x[1],reverse=True)
# draw the average roc curve of all classes
plt.figure()
lw = 2
plt.plot(fpr_dict["micro"], tpr_dict["micro"],
label='micro-average ROC curve (area = {0:0.2f})'
''.format(roc_auc_dict["micro"]),
color='deeppink', linestyle=':', linewidth=4)
plt.plot(fpr_dict["macro"], tpr_dict["macro"],
label='macro-average ROC curve (area = {0:0.2f})'
''.format(roc_auc_dict["macro"]),
color='navy', linestyle=':', linewidth=4)
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(5), colors):
category = roc_auc_dict_order[i][0]
plt.plot(fpr_dict[category], tpr_dict[category], color=color, lw=lw,
label='ROC curve of class {0} (area = {1:0.2f})'
''.format(category, roc_auc_dict_order[i][1]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.savefig('set113_roc.jpg')
plt.show()
C:\Users\Administrator\AppData\Local\Temp\ipykernel_9780\3326775581.py:32: DeprecationWarning: scipy.interp is deprecated and will be removed in SciPy 2.0.0, use numpy.interp instead mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i])
Note: All parts below here relate to the CNN model only and not the MLP! You are advised to use your final CNN model only for each of the following parts.
Using your (final) CNN model, use the strategies below to avoid overfitting. You can reuse the network weights from previous training, often referred to as fine tuning.
Plot loss and accuracy graphs per epoch side by side for each implemented strategy.
Implement at least five different data augmentation techniques that should include both photometric and geometric augmentations.
Provide graph and comment on what you observe
data_augmentation_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomRotation((-20,20)),
transforms.ColorJitter(hue=0.2, saturation=0.2, brightness=0.2),
transforms.RandomResizedCrop(64,scale=(0.7,1.0)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_set = MyDataset("train_set",transform=data_augmentation_transform)
length=len(train_set)
train_size,validate_size=int(0.8*length),int(0.2*length)
train_set,validate_set=torch.utils.data.random_split(train_set,[train_size,validate_size],generator=torch.Generator().manual_seed(0))
train_loader = DataLoader(
train_set,
batch_size = 64,
shuffle = True)
validate_loader = DataLoader(
validate_set,
batch_size = 64,
shuffle = True)
# Your code here!
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_model.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0
CNN_train_loss, CNN_validate_loss, CNN_train_accuracy, CNN_validate_accuracy = [], [], [], []
for epoch in range(nepochs):
train_running_loss , train_running_accuracy = train(train_loader, CNN_model, criterion, optimizer)
CNN_train_loss.append(train_running_loss)
CNN_train_accuracy.append(train_running_accuracy)
validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model, criterion, optimizer)
CNN_validate_loss.append(validate_running_loss)
CNN_validate_accuracy.append(validate_running_accuracy)
scheduler.step()
if validate_running_loss < best_loss:
best_loss = validate_running_loss
torch.save(CNN_model.state_dict(), './CNN_model.pt')
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets')
axs[0].plot(x_axis,CNN_train_loss,'r--',label='CNN_train_loss')
axs[0].plot(x_axis,CNN_validate_loss,'g--',label='CNN_validate_loss')
axs[1].plot(x_axis,CNN_train_accuracy,'b--',label='CNN_train_accuracy')
axs[1].plot(x_axis,CNN_validate_accuracy,'y--',label='CNN_validate_accuracy')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()
Before data augmentation, the training set was overfitting after a dozen epochs. the training set was approaching 100% too early, but the accuracy of the validation set was no longer increasing, proving that the model was learning too much useless information. By applying photometric and geometric augmentations to the images, the overfitting problem was solved to some extent and the accuracy of the validation set improved.
Implement dropout in your model
Provide graph and comment on your choice of proportion used
# Your code here!
possibility = [0.2,0.3,0.4,0.5,0.6,0.7]
CNN_train_loss = {"0.2":[],"0.3":[],"0.4":[],"0.5":[],"0.6":[],"0.7":[]}
CNN_validate_loss = {"0.2":[],"0.3":[],"0.4":[],"0.5":[],"0.6":[],"0.7":[]}
CNN_train_accuracy = {"0.2":[],"0.3":[],"0.4":[],"0.5":[],"0.6":[],"0.7":[]}
CNN_validate_accuracy = {"0.2":[],"0.3":[],"0.4":[],"0.5":[],"0.6":[],"0.7":[]}
for poss in possibility:
class CNN_Class_Improved(nn.Module):
def __init__(self):
super(CNN_Class_Improved,self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flc1 = nn.Linear(64*8*8,1024)
self.dropout = nn.Dropout(p=poss)
self.flc2 = nn.Linear(1024,30)
def forward(self,x):
x = self.maxpool1(nn.functional.relu(self.conv1(x)))
x = self.maxpool2(nn.functional.relu(self.conv2(x)))
x = self.maxpool3(nn.functional.relu(self.conv3(x)))
x = x.view(-1,64*8*8)
x = self.dropout(x)
x = nn.functional.relu(self.flc1(x))
x = self.flc2(x)
return x
CNN_model_Improved =CNN_Class_Improved()
CNN_model_Improved = CNN_model_Improved.to(device)
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_model_Improved.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0
for epoch in range(nepochs):
train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved, criterion, optimizer)
CNN_train_loss[str(poss)].append(train_running_loss)
CNN_train_accuracy[str(poss)].append(train_running_accuracy)
validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved, criterion, optimizer)
CNN_validate_loss[str(poss)].append(validate_running_loss)
CNN_validate_accuracy[str(poss)].append(validate_running_accuracy)
scheduler.step()
if validate_running_loss < best_loss:
best_loss = validate_running_loss
torch.save(CNN_model_Improved.state_dict(), './CNN_model_Improved.pt')
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,figsize=(15,20),sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets with dropout possibility')
for poss in possibility:
axs[0].plot(x_axis,CNN_train_loss[str(poss)],label='CNN_train_loss in '+str(poss))
axs[0].plot(x_axis,CNN_validate_loss[str(poss)],label='CNN_validate_loss in '+str(poss))
axs[1].plot(x_axis,CNN_train_accuracy[str(poss)],label='CNN_train_accuracy in '+str(poss))
axs[1].plot(x_axis,CNN_validate_accuracy[str(poss)],label='CNN_validate_accuracy in '+str(poss))
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()
As can be seen from the two plots above, when the probability of the Dropout layer is set to 0.4, the Loss is lower and Accuracy is higher for both the training and validation sets. Therefore I can tell that p=0.4 is more suitable for this CNN model. Below I redefine the model Class as CNN_Class_Improved, where the Dropout layer is set with a probability of 0.4.
class CNN_Class_Improved(nn.Module):
def __init__(self):
super(CNN_Class_Improved,self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flc1 = nn.Linear(64*8*8,1024)
self.dropout = nn.Dropout(p=0.4)
self.flc2 = nn.Linear(1024,30)
def forward(self,x):
x = self.maxpool1(nn.functional.relu(self.conv1(x)))
x = self.maxpool2(nn.functional.relu(self.conv2(x)))
x = self.maxpool3(nn.functional.relu(self.conv3(x)))
x = x.view(-1,64*8*8)
x = self.dropout(x)
x = nn.functional.relu(self.flc1(x))
x = self.flc2(x)
return x
CNN_model_Improved =CNN_Class_Improved()
CNN_model_Improved = CNN_model_Improved.to(device)
CNN_train_loss, CNN_validate_loss, CNN_train_accuracy, CNN_validate_accuracy = [], [], [], []
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_model_Improved.parameters(), 0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,2, gamma=0.8, last_epoch=-1)
best_loss = 1000.0
for epoch in range(nepochs):
train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved, criterion, optimizer)
CNN_train_loss.append(train_running_loss)
CNN_train_accuracy.append(train_running_accuracy)
validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved, criterion, optimizer)
CNN_validate_loss.append(validate_running_loss)
CNN_validate_accuracy.append(validate_running_accuracy)
scheduler.step()
if validate_running_loss < best_loss:
best_loss = validate_running_loss
torch.save(CNN_model_Improved.state_dict(), './CNN_model_Improved.pt')
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,figsize=(15,20),sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets with 0.4 dropout possibility')
axs[0].plot(x_axis,CNN_train_loss,label='CNN_train_loss')
axs[0].plot(x_axis,CNN_validate_loss,label='CNN_validate_loss')
axs[1].plot(x_axis,CNN_train_accuracy,label='CNN_train_accuracy')
axs[1].plot(x_axis,CNN_validate_accuracy,label='CNN_validate_accuracy')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()
Use learning rates [0.1, 0.001, 0.0001]
Provide separate graphs for loss and accuracy, each showing performance at three different learning rates
nepochs = 100
criterion = nn.CrossEntropyLoss()
best_loss = 1000.0
train_loss_lr_01 = []
train_acc_lr_01 = []
validate_loss_lr_01 = []
validate_acc_lr_01 = []
optimizer = torch.optim.Adam(CNN_model_Improved.parameters(), lr=0.1)
for epoch in range(nepochs):
train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved, criterion, optimizer)
train_loss_lr_01.append(train_running_loss)
train_acc_lr_01.append(train_running_accuracy)
validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved, criterion, optimizer)
validate_loss_lr_01.append(validate_running_loss)
validate_acc_lr_01.append(validate_running_accuracy)
if validate_running_loss < best_loss:
best_loss = validate_running_loss
torch.save(CNN_model_Improved.state_dict(), './CNN_model_Improved.pt')
print('epoch:{},train_loss:{}, train_acc:{}, val_loss:{}, val_acc:{}'.format(epoch+1,train_running_loss,train_running_accuracy,validate_running_loss,validate_running_accuracy))
epoch:1,train_loss:2164.766256813467, train_acc:3.2513561248779297, val_loss:3.4151479754337046, val_acc:3.3793604373931885 epoch:2,train_loss:3.417014720171866, train_acc:3.155818462371826, val_loss:3.4172210637913194, val_acc:3.1371123790740967 epoch:3,train_loss:3.4147649728334866, train_acc:3.260601758956909, val_loss:3.4149644984755407, val_acc:3.5004844665527344 epoch:4,train_loss:3.4186991643623488, train_acc:3.026380777359009, val_loss:3.4159545177637143, val_acc:3.4156975746154785 epoch:5,train_loss:3.4175619698135105, train_acc:3.2575197219848633, val_loss:3.415543012840803, val_acc:3.2703487873077393 epoch:6,train_loss:3.417246429172493, train_acc:3.125, val_loss:3.4096909234690114, val_acc:3.16133713722229 epoch:7,train_loss:3.418316842536249, train_acc:3.0171351432800293, val_loss:3.417351800341939, val_acc:3.5368216037750244 epoch:8,train_loss:3.418994481746967, train_acc:3.260601758956909, val_loss:3.4151616262835125, val_acc:3.633720874786377 epoch:9,train_loss:3.420121513174836, train_acc:3.0448715686798096, val_loss:3.414994395056436, val_acc:3.4883720874786377 epoch:10,train_loss:3.4162561258620765, train_acc:3.402366876602173, val_loss:3.4193347165750905, val_acc:2.579941749572754 epoch:11,train_loss:3.419583608412884, train_acc:2.8907790184020996, val_loss:3.426316582879355, val_acc:3.3430233001708984 epoch:12,train_loss:3.4176741506926405, train_acc:3.590359926223755, val_loss:3.4160396331964535, val_acc:3.633720874786377 epoch:13,train_loss:3.4167988328538703, train_acc:3.3099112510681152, val_loss:3.4130693701810615, val_acc:3.4156975746154785 epoch:14,train_loss:3.419260002452241, train_acc:2.98939847946167, val_loss:3.4215198117633197, val_acc:3.16133713722229 epoch:15,train_loss:3.417573000552386, train_acc:3.1342456340789795, val_loss:3.4205423454905666, val_acc:3.3066859245300293 epoch:16,train_loss:3.418926923232671, train_acc:3.0602810382843018, val_loss:3.4173313074333724, val_acc:3.4641470909118652 epoch:17,train_loss:3.416952684786193, train_acc:3.174309492111206, val_loss:3.4297832500102907, val_acc:3.1734495162963867 epoch:18,train_loss:3.419816007275553, train_acc:3.4054486751556396, val_loss:3.416986337927885, val_acc:3.427809953689575 epoch:19,train_loss:3.4161838441205448, train_acc:3.2051284313201904, val_loss:3.413021359332772, val_acc:3.3430233001708984 epoch:20,train_loss:3.4196937112413215, train_acc:3.248274087905884, val_loss:3.4172293363615522, val_acc:3.4520349502563477 epoch:21,train_loss:3.417763116091666, train_acc:3.568787097930908, val_loss:3.414470162502555, val_acc:3.4156975746154785 epoch:22,train_loss:3.4174171650903467, train_acc:3.2729289531707764, val_loss:3.4207970042561375, val_acc:3.16133713722229 epoch:23,train_loss:3.4184304313546807, train_acc:3.282174587249756, val_loss:3.414499371550804, val_acc:3.6579458713531494 epoch:24,train_loss:3.41857082321799, train_acc:3.1342456340789795, val_loss:3.4210255312365154, val_acc:3.234011650085449 epoch:25,train_loss:3.416219210483619, train_acc:3.4609220027923584, val_loss:3.4134753526643267, val_acc:3.4520349502563477 epoch:26,train_loss:3.4197405767158644, train_acc:3.088017702102661, val_loss:3.4153738520866215, val_acc:3.5731587409973145 epoch:27,train_loss:3.420783406884007, train_acc:3.4270217418670654, val_loss:3.4164660808651948, val_acc:2.579941749572754 epoch:28,train_loss:3.4167481569143443, train_acc:3.2051284313201904, val_loss:3.4207183094911797, val_acc:3.1734495162963867 epoch:29,train_loss:3.4205114516986193, train_acc:2.847633123397827, val_loss:3.418117695076521, val_acc:3.4156975746154785 epoch:30,train_loss:3.4190025611742008, train_acc:2.933925151824951, val_loss:3.418670748555383, val_acc:3.4520349502563477 epoch:31,train_loss:3.4182936806650557, train_acc:3.072608470916748, val_loss:3.4141087698382, val_acc:2.616279125213623 epoch:32,train_loss:3.4167870247857812, train_acc:3.52255916595459, val_loss:3.410075542538665, val_acc:2.773740291595459 epoch:33,train_loss:3.4175721456313273, train_acc:3.158900499343872, val_loss:3.4164779851602955, val_acc:3.427809953689575 epoch:34,train_loss:3.4188687293487185, train_acc:3.2174556255340576, val_loss:3.4167783703914907, val_acc:3.7306201457977295 epoch:35,train_loss:3.4183290385635647, train_acc:3.026380777359009, val_loss:3.4241222503573394, val_acc:3.125 epoch:36,train_loss:3.416948575239915, train_acc:3.0602810382843018, val_loss:3.4283005470453305, val_acc:3.125 epoch:37,train_loss:3.422103128489658, train_acc:3.161982297897339, val_loss:3.4196407850398574, val_acc:3.4641470909118652 epoch:38,train_loss:3.420339352985811, train_acc:3.0787723064422607, val_loss:3.411231368087059, val_acc:2.616279125213623 epoch:39,train_loss:3.4194834387514015, train_acc:3.1527366638183594, val_loss:3.425304961758991, val_acc:3.3430233001708984 epoch:40,train_loss:3.4183861207679884, train_acc:3.075690507888794, val_loss:3.4169876187346704, val_acc:3.3066859245300293 epoch:41,train_loss:3.41841212532224, train_acc:3.0109713077545166, val_loss:3.4204413391822994, val_acc:3.4641470909118652 epoch:42,train_loss:3.419490743670943, train_acc:3.0664448738098145, val_loss:3.4204198648763255, val_acc:3.3430233001708984 epoch:43,train_loss:3.417487470355965, train_acc:3.2359466552734375, val_loss:3.419445808543715, val_acc:3.3430233001708984 epoch:44,train_loss:3.4181505279428155, train_acc:3.3592207431793213, val_loss:3.422358607136926, val_acc:3.4520349502563477 epoch:45,train_loss:3.4212284158672808, train_acc:3.3099112510681152, val_loss:3.4102145461148994, val_acc:3.16133713722229 epoch:46,train_loss:3.4180980814984565, train_acc:3.023298740386963, val_loss:3.4169417314751205, val_acc:2.9796512126922607 epoch:47,train_loss:3.4190124929303956, train_acc:3.2051284313201904, val_loss:3.4161524384520776, val_acc:3.5368216037750244 epoch:48,train_loss:3.415075201960005, train_acc:3.23286509513855, val_loss:3.423741251923317, val_acc:3.633720874786377 epoch:49,train_loss:3.4189514478988197, train_acc:3.2883384227752686, val_loss:3.4172512154246486, val_acc:3.015988349914551 epoch:50,train_loss:3.417860299172486, train_acc:3.6088509559631348, val_loss:3.4190574967583944, val_acc:3.1734495162963867 epoch:51,train_loss:3.416621050185706, train_acc:3.137327194213867, val_loss:3.4166778298311455, val_acc:3.391472816467285 epoch:52,train_loss:3.418241732219267, train_acc:3.534886360168457, val_loss:3.417380976122479, val_acc:2.579941749572754 epoch:53,train_loss:3.4201752236608924, train_acc:3.161982297897339, val_loss:3.4176531337028324, val_acc:2.616279125213623 epoch:54,train_loss:3.417625130986321, train_acc:3.5965237617492676, val_loss:3.4332177361776663, val_acc:3.3066859245300293 epoch:55,train_loss:3.418752680163412, train_acc:3.094181537628174, val_loss:3.4256202620129255, val_acc:3.3793604373931885 epoch:56,train_loss:3.4195547611755734, train_acc:3.6612424850463867, val_loss:3.42174948093503, val_acc:3.4156975746154785 epoch:57,train_loss:3.4177634631388285, train_acc:3.23286509513855, val_loss:3.4214947944463687, val_acc:2.616279125213623 epoch:58,train_loss:3.421686240201871, train_acc:2.9770710468292236, val_loss:3.411846914956736, val_acc:3.4520349502563477 epoch:59,train_loss:3.4174652720344136, train_acc:3.5410501956939697, val_loss:3.4218920275222424, val_acc:2.8706395626068115 epoch:60,train_loss:3.420347655313255, train_acc:3.245192289352417, val_loss:3.4128389635751413, val_acc:2.9796512126922607 epoch:61,train_loss:3.4182452534782817, train_acc:2.9400887489318848, val_loss:3.422146891438684, val_acc:3.3430233001708984 epoch:62,train_loss:3.415685522485767, train_acc:3.2513561248779297, val_loss:3.4185479352640553, val_acc:3.4156975746154785 epoch:63,train_loss:3.417576566955747, train_acc:3.094181537628174, val_loss:3.4088850853055024, val_acc:3.234011650085449 epoch:64,train_loss:3.4182081278964613, train_acc:2.866124153137207, val_loss:3.416826420052107, val_acc:3.318798303604126 epoch:65,train_loss:3.418849708060541, train_acc:3.2051284313201904, val_loss:3.408453292624895, val_acc:3.3430233001708984 epoch:66,train_loss:3.4193298774358083, train_acc:3.226701259613037, val_loss:3.4256949535636014, val_acc:2.9796512126922607 epoch:67,train_loss:3.417695463056395, train_acc:3.3099112510681152, val_loss:3.415983948596688, val_acc:3.4641470909118652 epoch:68,train_loss:3.416713549540593, train_acc:3.260601758956909, val_loss:3.413653795109239, val_acc:3.318798303604126 epoch:69,train_loss:3.4188056514107967, train_acc:3.1650640964508057, val_loss:3.42025640398957, val_acc:3.439922571182251 epoch:70,train_loss:3.4203963942781708, train_acc:3.2544379234313965, val_loss:3.4184001656465752, val_acc:3.125 epoch:71,train_loss:3.4179979724996894, train_acc:3.174309492111206, val_loss:3.4074795856032263, val_acc:3.015988349914551 epoch:72,train_loss:3.417654975631533, train_acc:3.155818462371826, val_loss:3.417345950769824, val_acc:3.5004844665527344 epoch:73,train_loss:3.4162017861766927, train_acc:3.1681461334228516, val_loss:3.420057013977406, val_acc:3.3430233001708984 epoch:74,train_loss:3.419482875857833, train_acc:3.075690507888794, val_loss:3.4232967398887455, val_acc:3.3430233001708984 epoch:75,train_loss:3.4189944196734907, train_acc:3.229782819747925, val_loss:3.4140212092288706, val_acc:3.1371123790740967 epoch:76,train_loss:3.4205527813476926, train_acc:3.1465728282928467, val_loss:3.427143795545711, val_acc:2.737403154373169 epoch:77,train_loss:3.4187424592012485, train_acc:3.276010751724243, val_loss:3.4177852231402728, val_acc:2.579941749572754 epoch:78,train_loss:3.419346115292882, train_acc:3.3592207431793213, val_loss:3.4079314442568047, val_acc:3.4156975746154785 epoch:79,train_loss:3.419485705844044, train_acc:3.1804733276367188, val_loss:3.4205848228099733, val_acc:3.3430233001708984 epoch:80,train_loss:3.4190032087134186, train_acc:3.2852563858032227, val_loss:3.4090217435082724, val_acc:3.427809953689575 epoch:81,train_loss:3.418792549675033, train_acc:3.1804733276367188, val_loss:3.411422552064408, val_acc:3.3793604373931885 epoch:82,train_loss:3.4182332690650896, train_acc:3.00788950920105, val_loss:3.4216125122336454, val_acc:3.4156975746154785 epoch:83,train_loss:3.417820659614879, train_acc:3.2575197219848633, val_loss:3.411681385927422, val_acc:3.948643445968628 epoch:84,train_loss:3.4190039352552426, train_acc:3.0633628368377686, val_loss:3.412457349688508, val_acc:2.8706395626068115 epoch:85,train_loss:3.418832723910992, train_acc:2.9462523460388184, val_loss:3.414388662160829, val_acc:2.579941749572754 epoch:86,train_loss:3.4185697750226987, train_acc:3.4085307121276855, val_loss:3.416147836419039, val_acc:3.1371123790740967 epoch:87,train_loss:3.4195547343711175, train_acc:3.3438117504119873, val_loss:3.413477354271467, val_acc:3.16133713722229 epoch:88,train_loss:3.4165763332998966, train_acc:3.4763314723968506, val_loss:3.4125312040018483, val_acc:3.6458332538604736 epoch:89,train_loss:3.416477335980658, train_acc:3.294501781463623, val_loss:3.4151419373445733, val_acc:3.5004844665527344 epoch:90,train_loss:3.416105400175738, train_acc:3.41469407081604, val_loss:3.418804551279822, val_acc:3.4156975746154785 epoch:91,train_loss:3.419120620693681, train_acc:3.3191568851470947, val_loss:3.4148293206858082, val_acc:3.015988349914551 epoch:92,train_loss:3.417863178535326, train_acc:3.433185338973999, val_loss:3.4063180546427883, val_acc:3.5247092247009277 epoch:93,train_loss:3.4200609520342224, train_acc:3.3253207206726074, val_loss:3.413665455441142, val_acc:3.015988349914551 epoch:94,train_loss:3.4180003157734165, train_acc:3.1003451347351074, val_loss:3.416378331738849, val_acc:3.779069662094116 epoch:95,train_loss:3.421211277944802, train_acc:3.174309492111206, val_loss:3.416663186494694, val_acc:3.5368216037750244 epoch:96,train_loss:3.4179833626606055, train_acc:3.109590530395508, val_loss:3.41137527310571, val_acc:3.561046600341797 epoch:97,train_loss:3.4203324515438642, train_acc:3.23286509513855, val_loss:3.412340297255405, val_acc:3.3430233001708984 epoch:98,train_loss:3.419171688824716, train_acc:3.1989645957946777, val_loss:3.4238855949667997, val_acc:2.9796512126922607 epoch:99,train_loss:3.4177519950640978, train_acc:3.276010751724243, val_loss:3.417204640632452, val_acc:3.4641470909118652 epoch:100,train_loss:3.4182146978096144, train_acc:3.2667651176452637, val_loss:3.4184995030247887, val_acc:3.3066859245300293
CNN_model_Improved_0001 =CNN_Class_Improved()
CNN_model_Improved_0001 = CNN_model_Improved_0001.to(device)
train_loss_lr_0001 = []
train_acc_lr_0001 = []
validate_loss_lr_0001 = []
validate_acc_lr_0001 = []
optimizer = torch.optim.Adam(CNN_model_Improved_0001.parameters(), lr=0.001)
for epoch in range(nepochs):
train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved_0001, criterion, optimizer)
train_loss_lr_0001.append(train_running_loss)
train_acc_lr_0001.append(train_running_accuracy)
validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved_0001, criterion, optimizer)
validate_loss_lr_0001.append(validate_running_loss)
validate_acc_lr_0001.append(validate_running_accuracy)
if validate_running_loss < best_loss:
best_loss = validate_running_loss
torch.save(CNN_model_Improved_0001.state_dict(), './CNN_model_Improved_0001.pt')
print('epoch:{},train_loss:{}, train_acc:{}, val_loss:{}, val_acc:{}'.format(epoch+1,train_running_loss,train_running_accuracy,validate_running_loss,validate_running_accuracy))
epoch:1,train_loss:3.1911271859908243, train_acc:9.544502258300781, val_loss:2.986053084218225, val_acc:15.23740291595459 epoch:2,train_loss:2.916512415959285, train_acc:16.854660034179688, val_loss:2.801635553670484, val_acc:19.28294563293457 epoch:3,train_loss:2.736940922821767, train_acc:21.859590530395508, val_loss:2.682002843812455, val_acc:23.897769927978516 epoch:4,train_loss:2.620152408554709, train_acc:24.63326072692871, val_loss:2.6289401220720867, val_acc:26.392927169799805 epoch:5,train_loss:2.5206565207983616, train_acc:27.11107063293457, val_loss:2.5242844792299493, val_acc:26.320253372192383 epoch:6,train_loss:2.450798304123286, train_acc:29.262203216552734, val_loss:2.4371402125025905, val_acc:30.365793228149414 epoch:7,train_loss:2.4043825702554376, train_acc:30.544254302978516, val_loss:2.3683269938757254, val_acc:31.855619430541992 epoch:8,train_loss:2.3393820460731463, train_acc:31.992727279663086, val_loss:2.358646592428518, val_acc:32.65503692626953 epoch:9,train_loss:2.3059581927293857, train_acc:32.68922424316406, val_loss:2.31296735586122, val_acc:34.338661193847656 epoch:10,train_loss:2.2608506270414273, train_acc:34.310279846191406, val_loss:2.29988732448844, val_acc:34.14486312866211 epoch:11,train_loss:2.2049239909155127, train_acc:35.09307098388672, val_loss:2.272709147874699, val_acc:34.39922332763672 epoch:12,train_loss:2.1789202817092987, train_acc:36.04844665527344, val_loss:2.2244559609612753, val_acc:35.804264068603516 epoch:13,train_loss:2.1175631799641446, train_acc:38.06089782714844, val_loss:2.192433759223583, val_acc:35.84060287475586 epoch:14,train_loss:2.0764113967940654, train_acc:39.78057098388672, val_loss:2.1785506315009537, val_acc:36.19186019897461 epoch:15,train_loss:2.058042284299636, train_acc:38.59097671508789, val_loss:2.1554948840030406, val_acc:38.50532913208008 epoch:16,train_loss:2.046162011355338, train_acc:39.715850830078125, val_loss:2.1306101277817127, val_acc:37.89970779418945 epoch:17,train_loss:2.001729795918662, train_acc:40.51405334472656, val_loss:2.1323553185130275, val_acc:37.89970779418945 epoch:18,train_loss:1.9814729225000687, train_acc:41.58654022216797, val_loss:2.117563998976419, val_acc:38.26308059692383 epoch:19,train_loss:1.9500723964363866, train_acc:42.42788314819336, val_loss:2.10549044054608, val_acc:39.21996307373047 epoch:20,train_loss:1.9504304912668713, train_acc:42.594303131103516, val_loss:2.1764658024144725, val_acc:37.97238540649414 epoch:21,train_loss:1.9215907302833872, train_acc:43.044254302978516, val_loss:2.087116565815238, val_acc:40.164730072021484 epoch:22,train_loss:1.907182706883673, train_acc:43.7376708984375, val_loss:2.1189985164376193, val_acc:39.35319900512695 epoch:23,train_loss:1.8836283514485557, train_acc:43.41407775878906, val_loss:2.0775403366532434, val_acc:41.085269927978516 epoch:24,train_loss:1.8656544762955616, train_acc:44.12598419189453, val_loss:2.0783199127330336, val_acc:39.80135726928711 epoch:25,train_loss:1.8698474888265486, train_acc:44.2862434387207, val_loss:2.1137201092963993, val_acc:39.24418640136719 epoch:26,train_loss:1.8349722581502248, train_acc:45.55288314819336, val_loss:2.1221803953481273, val_acc:40.261627197265625 epoch:27,train_loss:1.8137300980867013, train_acc:45.294010162353516, val_loss:2.0796976006308268, val_acc:40.007266998291016 epoch:28,train_loss:1.8014230580019528, train_acc:46.39114761352539, val_loss:2.030866586884787, val_acc:43.24127960205078 epoch:29,train_loss:1.785078847902061, train_acc:46.9612922668457, val_loss:2.035461370335069, val_acc:41.0610466003418 epoch:30,train_loss:1.7569422474979648, train_acc:47.583824157714844, val_loss:2.068865886954374, val_acc:41.569766998291016 epoch:31,train_loss:1.737341010358912, train_acc:48.243343353271484, val_loss:2.073564939720686, val_acc:40.45542526245117 epoch:32,train_loss:1.7261173259577103, train_acc:47.65779113769531, val_loss:2.045435675354891, val_acc:40.285850524902344 epoch:33,train_loss:1.7337800271412325, train_acc:47.53143310546875, val_loss:2.072408454362736, val_acc:40.34641647338867 epoch:34,train_loss:1.6981979968279777, train_acc:48.841224670410156, val_loss:2.0786357868549437, val_acc:41.315406799316406 epoch:35,train_loss:1.6865440846900264, train_acc:48.918270111083984, val_loss:2.0532170035118282, val_acc:40.92781066894531 epoch:36,train_loss:1.6882125315581553, train_acc:48.96141815185547, val_loss:2.0575836924619453, val_acc:41.35174560546875 epoch:37,train_loss:1.6592367619452393, train_acc:49.6240119934082, val_loss:2.055528163909912, val_acc:41.18217086791992 epoch:38,train_loss:1.671578735289489, train_acc:49.725711822509766, val_loss:2.0572038029515465, val_acc:41.812015533447266 epoch:39,train_loss:1.6578207319304787, train_acc:50.015411376953125, val_loss:2.097941980805508, val_acc:40.44331359863281 epoch:40,train_loss:1.6344125976223918, train_acc:50.66876220703125, val_loss:2.0761267224023507, val_acc:40.60077667236328 epoch:41,train_loss:1.6378633284709863, train_acc:50.30818176269531, val_loss:2.0407898731009904, val_acc:41.63032913208008 epoch:42,train_loss:1.6259075494912953, train_acc:50.397560119628906, val_loss:2.083180918249973, val_acc:42.405521392822266 epoch:43,train_loss:1.6108569266528068, train_acc:50.56089782714844, val_loss:2.0441354041875797, val_acc:42.708335876464844 epoch:44,train_loss:1.5851273409713655, train_acc:51.35909652709961, val_loss:2.0792733985324237, val_acc:41.42441940307617 epoch:45,train_loss:1.5764018311303043, train_acc:51.550174713134766, val_loss:2.0797172701636026, val_acc:41.89680099487305 epoch:46,train_loss:1.5714580349668243, train_acc:51.9138298034668, val_loss:2.092423987943073, val_acc:41.87257766723633 epoch:47,train_loss:1.556293245603347, train_acc:52.77983474731445, val_loss:2.0449386807375176, val_acc:42.51453399658203 epoch:48,train_loss:1.5329540262560872, train_acc:53.374629974365234, val_loss:2.0446904620458914, val_acc:42.21172332763672 epoch:49,train_loss:1.5542095405815621, train_acc:52.42233657836914, val_loss:2.0868139377860135, val_acc:41.44864273071289 epoch:50,train_loss:1.517126593364061, train_acc:53.818416595458984, val_loss:2.0363581790480505, val_acc:43.75 epoch:51,train_loss:1.5025355089345627, train_acc:54.385475158691406, val_loss:2.1164575343908267, val_acc:41.96947479248047 epoch:52,train_loss:1.528262742877712, train_acc:53.25443649291992, val_loss:2.0643693070079006, val_acc:42.41763687133789 epoch:53,train_loss:1.5132660223887517, train_acc:54.07421112060547, val_loss:2.0913036978522013, val_acc:42.151161193847656 epoch:54,train_loss:1.4974592740719135, train_acc:54.203651428222656, val_loss:2.0812370583068494, val_acc:42.562984466552734 epoch:55,train_loss:1.4779720920077442, train_acc:54.92171859741211, val_loss:2.1356683553651323, val_acc:41.99370193481445 epoch:56,train_loss:1.4970618620426697, train_acc:54.024898529052734, val_loss:2.1307524609011272, val_acc:40.964149475097656 epoch:57,train_loss:1.4814021880104697, train_acc:55.26072311401367, val_loss:2.053830152334169, val_acc:41.63032913208008 epoch:58,train_loss:1.4547195914228992, train_acc:55.48878479003906, val_loss:2.0846458839815716, val_acc:43.374515533447266 epoch:59,train_loss:1.45824292072883, train_acc:55.69526672363281, val_loss:2.080652666646381, val_acc:42.11482620239258 epoch:60,train_loss:1.4478308946423277, train_acc:55.88634490966797, val_loss:2.070925468622252, val_acc:42.65988540649414 epoch:61,train_loss:1.4465424679440153, train_acc:55.54733657836914, val_loss:2.0982873384342637, val_acc:40.625 epoch:62,train_loss:1.417205399310095, train_acc:56.055843353271484, val_loss:2.0567419307176458, val_acc:42.11482620239258 epoch:63,train_loss:1.4109138677106101, train_acc:56.74618148803711, val_loss:2.0790942436040836, val_acc:41.19428253173828 epoch:64,train_loss:1.4270178206573576, train_acc:56.21609878540039, val_loss:2.0332969399385674, val_acc:44.2344970703125 epoch:65,train_loss:1.4109003572068977, train_acc:57.066688537597656, val_loss:2.13479960519214, val_acc:39.87403106689453 epoch:66,train_loss:1.4048020120203142, train_acc:57.44267654418945, val_loss:2.026360500690549, val_acc:44.40406799316406 epoch:67,train_loss:1.4019899685707318, train_acc:57.31940460205078, val_loss:2.040835521941961, val_acc:41.7514533996582 epoch:68,train_loss:1.4054638676389435, train_acc:56.81398010253906, val_loss:2.1204093888748523, val_acc:42.38129806518555 epoch:69,train_loss:1.3860898419950136, train_acc:57.44267654418945, val_loss:2.092866908672244, val_acc:43.544090270996094 epoch:70,train_loss:1.3722516675672587, train_acc:57.73545455932617, val_loss:2.1119466715080795, val_acc:42.97480392456055 epoch:71,train_loss:1.390832188566761, train_acc:57.75394821166992, val_loss:2.059696294540583, val_acc:43.02325439453125 epoch:72,train_loss:1.3767634534976891, train_acc:57.73545455932617, val_loss:2.069774339365405, val_acc:43.47141647338867 epoch:73,train_loss:1.3583349377445921, train_acc:58.14841842651367, val_loss:2.1562981633252876, val_acc:43.386627197265625 epoch:74,train_loss:1.3378441129210432, train_acc:58.44119644165039, val_loss:2.08634912690451, val_acc:44.367733001708984 epoch:75,train_loss:1.3603977222414412, train_acc:57.89878845214844, val_loss:2.085165400837743, val_acc:42.84156799316406 epoch:76,train_loss:1.332293771427764, train_acc:58.49050521850586, val_loss:2.1116821516391844, val_acc:42.35707092285156 epoch:77,train_loss:1.3401080881350138, train_acc:58.76787185668945, val_loss:2.1042044329088787, val_acc:42.79311752319336 epoch:78,train_loss:1.3381292012316235, train_acc:58.79253005981445, val_loss:2.0878591482029405, val_acc:44.682655334472656 epoch:79,train_loss:1.3298471570014954, train_acc:59.68010330200195, val_loss:2.0947961751804796, val_acc:42.92635726928711 epoch:80,train_loss:1.3317312982660778, train_acc:59.16851806640625, val_loss:2.1222938964533253, val_acc:42.50242233276367 epoch:81,train_loss:1.3220863324650647, train_acc:59.66161346435547, val_loss:2.072736096936603, val_acc:42.50242233276367 epoch:82,train_loss:1.3219773127482488, train_acc:59.32569122314453, val_loss:2.14878735431405, val_acc:41.48497772216797 epoch:83,train_loss:1.306615230247114, train_acc:59.547584533691406, val_loss:2.13507309902546, val_acc:41.90891647338867 epoch:84,train_loss:1.3059152470537896, train_acc:59.29795455932617, val_loss:2.1891779594643173, val_acc:42.53875732421875 epoch:85,train_loss:1.2921045835201557, train_acc:59.76023483276367, val_loss:2.1152667971544488, val_acc:42.042152404785156 epoch:86,train_loss:1.2735567209283276, train_acc:60.18552780151367, val_loss:2.128107009932052, val_acc:43.2655029296875 epoch:87,train_loss:1.2963687403667608, train_acc:60.3611946105957, val_loss:2.165547340415245, val_acc:42.28439712524414 epoch:88,train_loss:1.2947157050025533, train_acc:60.23483657836914, val_loss:2.109045857606932, val_acc:43.713661193847656 epoch:89,train_loss:1.2850948854310977, train_acc:60.34270477294922, val_loss:2.1421401417532633, val_acc:43.483524322509766 epoch:90,train_loss:1.2887458261653517, train_acc:60.17628479003906, val_loss:2.1347553785457167, val_acc:41.95736312866211 epoch:91,train_loss:1.2774624993815225, train_acc:60.15470886230469, val_loss:2.111482395682224, val_acc:41.71511459350586 epoch:92,train_loss:1.258455208420048, train_acc:61.09159469604492, val_loss:2.1545954942703247, val_acc:43.21705627441406 epoch:93,train_loss:1.2406606804689713, train_acc:61.153228759765625, val_loss:2.1607651239217716, val_acc:43.21705627441406 epoch:94,train_loss:1.2692347514558826, train_acc:60.57384490966797, val_loss:2.127889716347983, val_acc:43.21705627441406 epoch:95,train_loss:1.2696357436434051, train_acc:60.5861701965332, val_loss:2.143072849096254, val_acc:42.902130126953125 epoch:96,train_loss:1.2370165519460419, train_acc:61.90520477294922, val_loss:2.181827423184417, val_acc:42.10271072387695 epoch:97,train_loss:1.25053066791162, train_acc:61.60626220703125, val_loss:2.176362212314162, val_acc:43.77422332763672 epoch:98,train_loss:1.2255728777343704, train_acc:62.244205474853516, val_loss:2.1287347937739174, val_acc:41.65455627441406 epoch:99,train_loss:1.223124155278742, train_acc:62.13325881958008, val_loss:2.218731730483299, val_acc:41.60610580444336 epoch:100,train_loss:1.2221639054061393, train_acc:61.809661865234375, val_loss:2.144239292588345, val_acc:42.47819900512695
CNN_model_Improved_00001 =CNN_Class_Improved()
CNN_model_Improved_00001 = CNN_model_Improved_00001.to(device)
train_loss_lr_00001 = []
train_acc_lr_00001 = []
validate_loss_lr_00001 = []
validate_acc_lr_00001 = []
optimizer = torch.optim.Adam(CNN_model_Improved_00001.parameters(), lr=0.0001)
nepochs=100
for epoch in range(nepochs):
train_running_loss , train_running_accuracy = train(train_loader, CNN_model_Improved_00001, criterion, optimizer)
train_loss_lr_00001.append(train_running_loss)
train_acc_lr_00001.append(train_running_accuracy)
validate_running_loss , validate_running_accuracy = validate(validate_loader, CNN_model_Improved_00001, criterion, optimizer)
validate_loss_lr_00001.append(validate_running_loss)
validate_acc_lr_00001.append(validate_running_accuracy)
if validate_running_loss < best_loss:
best_loss = validate_running_loss
torch.save(CNN_model_Improved_00001.state_dict(), './CNN_model_Improved_00001.pt')
print('epoch:{},train_loss:{}, train_acc:{}, val_loss:{}, val_acc:{}'.format(epoch+1,train_running_loss,train_running_accuracy,validate_running_loss,validate_running_accuracy))
epoch:1,train_loss:3.2796755302587206, train_acc:8.444280624389648, val_loss:3.1328728864359303, val_acc:12.051841735839844 epoch:2,train_loss:3.066263253872211, train_acc:13.929981231689453, val_loss:3.001076387804608, val_acc:15.964146614074707 epoch:3,train_loss:2.941012852290678, train_acc:17.239891052246094, val_loss:2.876550230869027, val_acc:18.265504837036133 epoch:4,train_loss:2.851026439102444, train_acc:19.125986099243164, val_loss:2.810957442882449, val_acc:20.32461166381836 epoch:5,train_loss:2.787592075280184, train_acc:20.972015380859375, val_loss:2.7556045831635942, val_acc:21.790212631225586 epoch:6,train_loss:2.7347744526947744, train_acc:22.343442916870117, val_loss:2.7041152854298436, val_acc:22.953004837036133 epoch:7,train_loss:2.6939765955569475, train_acc:23.6624755859375, val_loss:2.692073761030685, val_acc:22.347383499145508 epoch:8,train_loss:2.6639692994969835, train_acc:24.33740234375, val_loss:2.6464970333631648, val_acc:25.557170867919922 epoch:9,train_loss:2.634774837268175, train_acc:24.919872283935547, val_loss:2.6104013919830322, val_acc:26.392927169799805 epoch:10,train_loss:2.593165810996964, train_acc:26.010848999023438, val_loss:2.6000261362208876, val_acc:25.968990325927734 epoch:11,train_loss:2.569996304765961, train_acc:26.919994354248047, val_loss:2.575469815453818, val_acc:25.121124267578125 epoch:12,train_loss:2.53580813153961, train_acc:27.671966552734375, val_loss:2.5255050603733507, val_acc:28.37936019897461 epoch:13,train_loss:2.5142372176492, train_acc:27.85687828063965, val_loss:2.5231597811676734, val_acc:28.10077667236328 epoch:14,train_loss:2.484837282338791, train_acc:28.762943267822266, val_loss:2.512596291165019, val_acc:28.972869873046875 epoch:15,train_loss:2.4733378957714556, train_acc:29.570390701293945, val_loss:2.4739946099214776, val_acc:30.1719970703125 epoch:16,train_loss:2.457911544297574, train_acc:29.219058990478516, val_loss:2.4756919561430464, val_acc:29.832849502563477 epoch:17,train_loss:2.4347294315078556, train_acc:29.893985748291016, val_loss:2.4439604947733327, val_acc:30.3900203704834 epoch:18,train_loss:2.412981839575006, train_acc:31.135971069335938, val_loss:2.4274158671844837, val_acc:32.46124267578125 epoch:19,train_loss:2.3984049692661804, train_acc:31.274654388427734, val_loss:2.413873550503753, val_acc:31.782943725585938 epoch:20,train_loss:2.367776843217703, train_acc:32.33481216430664, val_loss:2.39947415507117, val_acc:31.31056022644043 epoch:21,train_loss:2.357298065219405, train_acc:31.921842575073242, val_loss:2.3900515478710798, val_acc:30.95930290222168 epoch:22,train_loss:2.333338290276612, train_acc:32.633750915527344, val_loss:2.3940046221710913, val_acc:32.30377960205078 epoch:23,train_loss:2.3160707689601288, train_acc:33.906558990478516, val_loss:2.393957775692607, val_acc:31.649709701538086 epoch:24,train_loss:2.3079099768012235, train_acc:33.56755447387695, val_loss:2.34613381984622, val_acc:32.63081359863281 epoch:25,train_loss:2.296607259462571, train_acc:33.58296585083008, val_loss:2.3192355078320173, val_acc:33.34544372558594 epoch:26,train_loss:2.2738238241545545, train_acc:34.26097106933594, val_loss:2.332012026808983, val_acc:33.406009674072266 epoch:27,train_loss:2.253186898824026, train_acc:35.20709991455078, val_loss:2.315386486607929, val_acc:33.987403869628906 epoch:28,train_loss:2.250527298662084, train_acc:35.1886100769043, val_loss:2.335046102834302, val_acc:32.945735931396484 epoch:29,train_loss:2.2227693415252414, train_acc:35.265655517578125, val_loss:2.288711567257726, val_acc:33.90261459350586 epoch:30,train_loss:2.220126310749167, train_acc:35.99913787841797, val_loss:2.3039729567461236, val_acc:34.56879806518555 epoch:31,train_loss:2.2153540913169905, train_acc:36.15323257446289, val_loss:2.293829879095388, val_acc:34.68992233276367 epoch:32,train_loss:2.1947776474190888, train_acc:36.24260330200195, val_loss:2.2992831524028334, val_acc:34.02374267578125 epoch:33,train_loss:2.177062145351658, train_acc:36.9760856628418, val_loss:2.268732553304628, val_acc:36.36143112182617 epoch:34,train_loss:2.1558406670418013, train_acc:37.090110778808594, val_loss:2.2572834325391193, val_acc:35.86482620239258 epoch:35,train_loss:2.1473963056090315, train_acc:37.60786437988281, val_loss:2.2591070408044858, val_acc:36.19186019897461 epoch:36,train_loss:2.146941250597937, train_acc:37.04080581665039, val_loss:2.2474963692731635, val_acc:35.99806213378906 epoch:37,train_loss:2.1229546514488535, train_acc:38.76972579956055, val_loss:2.2526982856351276, val_acc:36.143409729003906 epoch:38,train_loss:2.1176861496366692, train_acc:38.396820068359375, val_loss:2.22594033008398, val_acc:36.66424560546875 epoch:39,train_loss:2.1076269269694943, train_acc:38.88683319091797, val_loss:2.2588198295859403, val_acc:36.5794563293457 epoch:40,train_loss:2.108998182962632, train_acc:38.71425247192383, val_loss:2.2012144382609877, val_acc:37.294090270996094 epoch:41,train_loss:2.0898374890434672, train_acc:39.087154388427734, val_loss:2.1979586246401763, val_acc:36.48255920410156 epoch:42,train_loss:2.0901577211696014, train_acc:38.57864761352539, val_loss:2.2086341076119003, val_acc:36.02228927612305 epoch:43,train_loss:2.0680763940133993, train_acc:39.318294525146484, val_loss:2.191249215325644, val_acc:35.828487396240234 epoch:44,train_loss:2.054594759405012, train_acc:39.712772369384766, val_loss:2.1973287033480267, val_acc:36.5794563293457 epoch:45,train_loss:2.039196566011779, train_acc:40.13806915283203, val_loss:2.1795612019161847, val_acc:37.28197479248047 epoch:46,train_loss:2.0301347343173957, train_acc:40.52946090698242, val_loss:2.2055412974468496, val_acc:35.804264068603516 epoch:47,train_loss:2.014961471924415, train_acc:40.94859313964844, val_loss:2.1725045913873715, val_acc:38.674903869628906 epoch:48,train_loss:2.0152080729162907, train_acc:41.60503005981445, val_loss:2.1827113212541094, val_acc:38.468990325927734 epoch:49,train_loss:2.003793822237726, train_acc:40.79450225830078, val_loss:2.159993235455003, val_acc:39.4985466003418 epoch:50,train_loss:1.9894032076265684, train_acc:42.11045455932617, val_loss:2.171448075494101, val_acc:37.14874267578125 epoch:51,train_loss:1.970363364417172, train_acc:42.51109313964844, val_loss:2.138068362723949, val_acc:39.7044563293457 epoch:52,train_loss:1.9762740981649365, train_acc:41.777610778808594, val_loss:2.143760695013889, val_acc:38.45688247680664 epoch:53,train_loss:1.9773973748528746, train_acc:41.722137451171875, val_loss:2.1409021532812784, val_acc:38.941375732421875 epoch:54,train_loss:1.9595034891331689, train_acc:42.05806350708008, val_loss:2.1526396108228107, val_acc:38.48110580444336 epoch:55,train_loss:1.942931172410412, train_acc:43.14287567138672, val_loss:2.129172862962235, val_acc:39.60755920410156 epoch:56,train_loss:1.9418711246118037, train_acc:42.803871154785156, val_loss:2.1091779359551364, val_acc:40.007266998291016 epoch:57,train_loss:1.9373988230552899, train_acc:42.785377502441406, val_loss:2.13331177345542, val_acc:38.32364273071289 epoch:58,train_loss:1.9329443831415571, train_acc:43.19834899902344, val_loss:2.1204048672387765, val_acc:39.83769607543945 epoch:59,train_loss:1.918922458174666, train_acc:43.85786437988281, val_loss:2.1178600788116455, val_acc:39.03827667236328 epoch:60,train_loss:1.8999799595782036, train_acc:44.09208679199219, val_loss:2.1590126863745756, val_acc:38.32364273071289 epoch:61,train_loss:1.8980442722873574, train_acc:44.10441589355469, val_loss:2.1230733616407527, val_acc:39.38953399658203 epoch:62,train_loss:1.901219623328666, train_acc:43.84553909301758, val_loss:2.1353326814119207, val_acc:39.171512603759766 epoch:63,train_loss:1.8761565579465156, train_acc:44.77317428588867, val_loss:2.1243791691092557, val_acc:39.1109504699707 epoch:64,train_loss:1.8660558187044585, train_acc:45.35564422607422, val_loss:2.147592017816943, val_acc:38.31153106689453 epoch:65,train_loss:1.8696399419265386, train_acc:45.130672454833984, val_loss:2.092461014902869, val_acc:40.176841735839844 epoch:66,train_loss:1.8564818935281426, train_acc:44.95808410644531, val_loss:2.090336184168971, val_acc:40.104164123535156 epoch:67,train_loss:1.8503782375324407, train_acc:45.485084533691406, val_loss:2.0790175005447034, val_acc:40.31007766723633 epoch:68,train_loss:1.8471501198040663, train_acc:45.42344665527344, val_loss:2.120376140572304, val_acc:40.5765495300293 epoch:69,train_loss:1.8262064047819058, train_acc:46.02440643310547, val_loss:2.0704417395037273, val_acc:40.72189712524414 epoch:70,train_loss:1.8295576847516573, train_acc:46.14459991455078, val_loss:2.0638344759164853, val_acc:41.15794372558594 epoch:71,train_loss:1.8096729601628683, train_acc:46.11994552612305, val_loss:2.0690932162972384, val_acc:40.89147186279297 epoch:72,train_loss:1.8182336578707723, train_acc:46.54524230957031, val_loss:2.0709913081901017, val_acc:40.87936019897461 epoch:73,train_loss:1.8036771645912757, train_acc:46.465110778808594, val_loss:2.1412681368894355, val_acc:39.401649475097656 epoch:74,train_loss:1.7917747934894448, train_acc:46.28944778442383, val_loss:2.067728890929111, val_acc:40.5765495300293 epoch:75,train_loss:1.7835359326481113, train_acc:47.377342224121094, val_loss:2.106135518051857, val_acc:39.83769607543945 epoch:76,train_loss:1.786717737214805, train_acc:46.85959243774414, val_loss:2.0853390056033465, val_acc:41.218509674072266 epoch:77,train_loss:1.7747366879818707, train_acc:47.759490966796875, val_loss:2.0839045491329458, val_acc:41.145835876464844 epoch:78,train_loss:1.7654184201765342, train_acc:47.54376220703125, val_loss:2.0704863293226374, val_acc:41.497093200683594 epoch:79,train_loss:1.7546812049030551, train_acc:47.61464309692383, val_loss:2.071953992510951, val_acc:41.315406799316406 epoch:80,train_loss:1.7469937568585547, train_acc:47.716346740722656, val_loss:2.0678959364114804, val_acc:42.42974853515625 epoch:81,train_loss:1.7372189950660841, train_acc:48.283409118652344, val_loss:2.0594216058420582, val_acc:42.27228927612305 epoch:82,train_loss:1.7496934127525465, train_acc:48.2125244140625, val_loss:2.0404803475668265, val_acc:42.75678253173828 epoch:83,train_loss:1.7196281957908495, train_acc:48.34812545776367, val_loss:2.0440582896387856, val_acc:41.37596893310547 epoch:84,train_loss:1.720940080620128, train_acc:48.579261779785156, val_loss:2.083317277043365, val_acc:39.84980392456055 epoch:85,train_loss:1.7203708492087189, train_acc:48.83197784423828, val_loss:2.060117527495983, val_acc:42.393409729003906 epoch:86,train_loss:1.7112735199505056, train_acc:48.94292449951172, val_loss:2.0421820490859277, val_acc:42.09060287475586 epoch:87,train_loss:1.6973145960350713, train_acc:49.46067810058594, val_loss:2.0402699514876965, val_acc:42.01792526245117 epoch:88,train_loss:1.7020337757979624, train_acc:49.44218444824219, val_loss:2.08253134960352, val_acc:41.0610466003418 epoch:89,train_loss:1.6742494550682383, train_acc:50.42221450805664, val_loss:2.0706246176431344, val_acc:41.64244079589844 epoch:90,train_loss:1.676584166182569, train_acc:49.953773498535156, val_loss:2.073023798853852, val_acc:41.17005920410156 epoch:91,train_loss:1.6772573897119105, train_acc:50.32051467895508, val_loss:2.0418946632119113, val_acc:41.88468933105469 epoch:92,train_loss:1.6651838601693598, train_acc:50.431461334228516, val_loss:2.0438435743021413, val_acc:41.46075439453125 epoch:93,train_loss:1.6536222773896168, train_acc:50.400638580322266, val_loss:2.0570213101630985, val_acc:41.37596893310547 epoch:94,train_loss:1.6471364462869407, train_acc:50.714988708496094, val_loss:2.053965829139532, val_acc:41.04893112182617 epoch:95,train_loss:1.631173533800791, train_acc:51.46696472167969, val_loss:2.039872352467027, val_acc:42.24806213378906 epoch:96,train_loss:1.633132583996248, train_acc:50.59479522705078, val_loss:2.0786331587059554, val_acc:40.21317672729492 epoch:97,train_loss:1.623470737383916, train_acc:51.40532684326172, val_loss:2.0779663573863894, val_acc:41.35174560546875 epoch:98,train_loss:1.6344603582246768, train_acc:51.14645004272461, val_loss:2.018333870311116, val_acc:43.82267379760742 epoch:99,train_loss:1.6260780554551344, train_acc:51.4484748840332, val_loss:2.0569261035253836, val_acc:41.254844665527344 epoch:100,train_loss:1.6022041372998932, train_acc:51.809051513671875, val_loss:2.0680011427679728, val_acc:41.96947479248047
# Your graph
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,1,figsize=(15,20),sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets in different learning rate')
axs[0].plot(x_axis[1:],train_loss_lr_01[1:],label='train_loss at 0.1 learning rate')
axs[0].plot(x_axis,validate_loss_lr_01,label='validate_loss at 0.1 learning rate')
axs[0].plot(x_axis,train_loss_lr_0001,label='train_loss at 0.001 learning rate')
axs[0].plot(x_axis,validate_loss_lr_0001,label='validate_loss at 0.001 learning rate')
axs[0].plot(x_axis,train_loss_lr_00001,label='train_loss at 0.0001 learning rate')
axs[0].plot(x_axis,validate_loss_lr_00001,label='validate_loss at 0.0001 learning rate')
axs[1].plot(x_axis,train_acc_lr_01,label='train_accuracy at 0.1 learning rate')
axs[1].plot(x_axis,validate_acc_lr_01,label='validate_accuracy at 0.1 learning rate')
axs[1].plot(x_axis,train_acc_lr_0001,label='train_accuracy at 0.001 learning rate')
axs[1].plot(x_axis,validate_acc_lr_0001,label='validate_accuracy at 0.001 learning rate')
axs[1].plot(x_axis,train_acc_lr_00001,label='train_accuracy at 0.0001 learning rate')
axs[1].plot(x_axis,validate_acc_lr_00001,label='validate_accuracy at 0.0001 learning rate')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()
By comparing the three cases of learning rate of 0.1, 0.001, and 0.0001, it can be seen from the figure that when the learning rate is set to 0.001, the Loss can converge better and obtain higher accuracy, so we think the model CNN_model_Improved_0001 (learning rate of 0.001) is more accurate.
# Your code here!
test_set = MyDataset("test_set")
test_loader = DataLoader(
test_set,
batch_size = 64,
shuffle = False)
Save all test predictions to a CSV file and submit it to the private class Kaggle competition.
# Your code here!
num_class = len(classes)
CNN_model_Improved_0001.load_state_dict(torch.load('./CNN_model_Improved_0001.pt'))
dic = {"Id":[],"Category":[]}
with torch.no_grad():
for data in test_loader:
images, labels = data
images = images.to(device)
outputs = CNN_model_Improved_0001(images)
labels = [l.replace("comp5625M_data_assessment_1\\test_set\\test_set\\","") for l in labels]
pre_y = torch.max(outputs, dim=1)[1].cpu().numpy()
dic["Id"].extend(labels)
dic["Category"].extend(pre_y)
df = pd.DataFrame.from_dict(dic, orient='index').T
df.to_csv("ml21zw.csv",index = False)
Fine-tuning is a way of applying or utilizing transfer learning. It is a process that takes a model that has already been trained for one task and then tunes or tweaks the model to make it perform a second similar task. You can perform finetuning in the following way:
Configuring your dataset
- Download your dataset using
torchvision.datasets.CIFAR10, explained here- Split training dataset into training and validation set similar to above. Note that the number of categories here are only 10
# Your code here!
transform = transforms.Compose(
[ transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
CIFAR10trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
length=len(CIFAR10trainset)
CIFAR10train_size,CIFAR10validate_size=int(0.8*length),int(0.2*length)
CIFAR10trainset,CIFAR10validateset=torch.utils.data.random_split(CIFAR10trainset,[CIFAR10train_size,CIFAR10validate_size],generator=torch.Generator().manual_seed(0))
print(len(CIFAR10trainset),len(CIFAR10validateset))
CIFAR10trainloader = torch.utils.data.DataLoader(CIFAR10trainset, batch_size=4,
shuffle=True, num_workers=2)
CIFAR10validateloader = torch.utils.data.DataLoader(CIFAR10validateset, batch_size=4,
shuffle=True, num_workers=2)
Files already downloaded and verified 40000 10000
Load pretrained AlexNet from PyTorch - use model copies to apply transfer learning in different configurations
# Your code here!
import torchvision.models as models
alexnet = models.alexnet(pretrained=True)
num_fc = alexnet.classifier[6].in_features
alexnet.classifier[6] = torch.nn.Linear(in_features=num_fc, out_features=10)
alexnet = alexnet.to(device)
print(alexnet)
AlexNet(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): ReLU(inplace=True)
(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU(inplace=True)
(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(inplace=True)
(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(inplace=True)
(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
(classifier): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=9216, out_features=4096, bias=True)
(2): ReLU(inplace=True)
(3): Dropout(p=0.5, inplace=False)
(4): Linear(in_features=4096, out_features=4096, bias=True)
(5): ReLU(inplace=True)
(6): Linear(in_features=4096, out_features=10, bias=True)
)
)
Use pretrained weights from AlexNet only (on the right of figure) to initialise your model.
Configuration 1: No frozen layers
# Your model changes here - also print trainable parameters
total_params = sum(p.numel() for p in alexnet.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
p.numel() for p in alexnet.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} trainable parameters.')
nepochs=100
optimizer = torch.optim.Adam(alexnet.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
alexnet_best_loss = 1000
alexnet_train_loss, alexnet_validate_loss, alexnet_train_accuracy, alexnet_validate_accuracy = [], [], [], []
for epoch in range(nepochs):
alexnet_train_running_loss , alexnet_train_running_accuracy = train(CIFAR10trainloader, alexnet, criterion, optimizer)
alexnet_train_loss.append(alexnet_train_running_loss)
alexnet_train_accuracy.append(alexnet_train_running_accuracy)
alexnet_validate_running_loss , alexnet_validate_running_accuracy = validate(CIFAR10validateloader, alexnet, criterion, optimizer)
alexnet_validate_loss.append(alexnet_validate_running_loss)
alexnet_validate_accuracy.append(alexnet_validate_running_accuracy)
if alexnet_validate_running_loss < alexnet_best_loss:
alexnet_best_loss = alexnet_validate_running_loss
torch.save(alexnet.state_dict(), './alexnet.pt')
print(f"epoch: {epoch+1} alexnet_train_loss: {alexnet_train_loss[epoch] : .3f} alexnet_train_accuracy: {alexnet_train_accuracy[epoch] : .3f} alexnet_validate_loss: {alexnet_validate_loss[epoch] : .3f} alexnet_validate_accuracy: {alexnet_validate_accuracy[epoch] : .3f}")
57,044,810 total parameters. 57,044,810 trainable parameters. epoch: 1 alexnet_train_loss: 0.394 alexnet_train_accuracy: 87.548 alexnet_validate_loss: 0.442 alexnet_validate_accuracy: 86.210 epoch: 2 alexnet_train_loss: 0.333 alexnet_train_accuracy: 89.198 alexnet_validate_loss: 0.464 alexnet_validate_accuracy: 85.180 epoch: 3 alexnet_train_loss: 0.308 alexnet_train_accuracy: 90.020 alexnet_validate_loss: 0.494 alexnet_validate_accuracy: 86.010 epoch: 4 alexnet_train_loss: 0.271 alexnet_train_accuracy: 91.365 alexnet_validate_loss: 0.498 alexnet_validate_accuracy: 84.900 epoch: 5 alexnet_train_loss: 0.255 alexnet_train_accuracy: 91.923 alexnet_validate_loss: 0.499 alexnet_validate_accuracy: 85.640 epoch: 6 alexnet_train_loss: 0.246 alexnet_train_accuracy: 92.272 alexnet_validate_loss: 0.531 alexnet_validate_accuracy: 85.600 epoch: 7 alexnet_train_loss: 0.238 alexnet_train_accuracy: 92.600 alexnet_validate_loss: 0.486 alexnet_validate_accuracy: 85.370 epoch: 8 alexnet_train_loss: 0.223 alexnet_train_accuracy: 93.005 alexnet_validate_loss: 0.592 alexnet_validate_accuracy: 86.620 epoch: 9 alexnet_train_loss: 0.215 alexnet_train_accuracy: 93.397 alexnet_validate_loss: 0.650 alexnet_validate_accuracy: 84.940 epoch: 10 alexnet_train_loss: 0.220 alexnet_train_accuracy: 93.335 alexnet_validate_loss: 0.453 alexnet_validate_accuracy: 86.120 epoch: 11 alexnet_train_loss: 0.211 alexnet_train_accuracy: 93.745 alexnet_validate_loss: 0.533 alexnet_validate_accuracy: 85.970 epoch: 12 alexnet_train_loss: 0.208 alexnet_train_accuracy: 93.832 alexnet_validate_loss: 0.462 alexnet_validate_accuracy: 86.240 epoch: 13 alexnet_train_loss: 0.217 alexnet_train_accuracy: 93.512 alexnet_validate_loss: 0.548 alexnet_validate_accuracy: 83.470 epoch: 14 alexnet_train_loss: 0.226 alexnet_train_accuracy: 93.488 alexnet_validate_loss: 0.528 alexnet_validate_accuracy: 83.600 epoch: 15 alexnet_train_loss: 0.228 alexnet_train_accuracy: 93.338 alexnet_validate_loss: 0.503 alexnet_validate_accuracy: 85.330 epoch: 16 alexnet_train_loss: 0.232 alexnet_train_accuracy: 93.255 alexnet_validate_loss: 0.514 alexnet_validate_accuracy: 84.990 epoch: 17 alexnet_train_loss: 0.201 alexnet_train_accuracy: 94.115 alexnet_validate_loss: 0.588 alexnet_validate_accuracy: 87.160 epoch: 18 alexnet_train_loss: 0.237 alexnet_train_accuracy: 93.277 alexnet_validate_loss: 0.554 alexnet_validate_accuracy: 85.870 epoch: 19 alexnet_train_loss: 0.220 alexnet_train_accuracy: 93.665 alexnet_validate_loss: 0.636 alexnet_validate_accuracy: 84.580 epoch: 20 alexnet_train_loss: 0.230 alexnet_train_accuracy: 93.592 alexnet_validate_loss: 0.625 alexnet_validate_accuracy: 85.690 epoch: 21 alexnet_train_loss: 0.223 alexnet_train_accuracy: 93.670 alexnet_validate_loss: 0.734 alexnet_validate_accuracy: 84.950 epoch: 22 alexnet_train_loss: 0.228 alexnet_train_accuracy: 93.442 alexnet_validate_loss: 0.587 alexnet_validate_accuracy: 86.120 epoch: 23 alexnet_train_loss: 0.236 alexnet_train_accuracy: 93.415 alexnet_validate_loss: 0.563 alexnet_validate_accuracy: 86.530 epoch: 24 alexnet_train_loss: 0.232 alexnet_train_accuracy: 93.308 alexnet_validate_loss: 0.659 alexnet_validate_accuracy: 85.700 epoch: 25 alexnet_train_loss: 0.309 alexnet_train_accuracy: 92.298 alexnet_validate_loss: 0.536 alexnet_validate_accuracy: 85.710 epoch: 26 alexnet_train_loss: 0.321 alexnet_train_accuracy: 91.825 alexnet_validate_loss: 0.566 alexnet_validate_accuracy: 84.500 epoch: 27 alexnet_train_loss: 0.272 alexnet_train_accuracy: 92.228 alexnet_validate_loss: 0.720 alexnet_validate_accuracy: 76.640 epoch: 28 alexnet_train_loss: 0.348 alexnet_train_accuracy: 90.325 alexnet_validate_loss: 0.638 alexnet_validate_accuracy: 84.810 epoch: 29 alexnet_train_loss: 0.282 alexnet_train_accuracy: 92.220 alexnet_validate_loss: 0.655 alexnet_validate_accuracy: 82.370 epoch: 30 alexnet_train_loss: 0.309 alexnet_train_accuracy: 91.463 alexnet_validate_loss: 0.658 alexnet_validate_accuracy: 83.710 epoch: 31 alexnet_train_loss: 0.280 alexnet_train_accuracy: 91.950 alexnet_validate_loss: 0.549 alexnet_validate_accuracy: 84.320 epoch: 32 alexnet_train_loss: 0.312 alexnet_train_accuracy: 91.923 alexnet_validate_loss: 0.698 alexnet_validate_accuracy: 78.330 epoch: 33 alexnet_train_loss: 0.315 alexnet_train_accuracy: 91.230 alexnet_validate_loss: 0.658 alexnet_validate_accuracy: 85.120 epoch: 34 alexnet_train_loss: 0.320 alexnet_train_accuracy: 91.110 alexnet_validate_loss: 0.799 alexnet_validate_accuracy: 74.940 epoch: 35 alexnet_train_loss: 0.340 alexnet_train_accuracy: 90.615 alexnet_validate_loss: 0.658 alexnet_validate_accuracy: 79.170 epoch: 36 alexnet_train_loss: 0.717 alexnet_train_accuracy: 87.872 alexnet_validate_loss: 0.600 alexnet_validate_accuracy: 84.700 epoch: 37 alexnet_train_loss: 0.397 alexnet_train_accuracy: 89.100 alexnet_validate_loss: 0.738 alexnet_validate_accuracy: 77.710 epoch: 38 alexnet_train_loss: 0.358 alexnet_train_accuracy: 90.077 alexnet_validate_loss: 0.858 alexnet_validate_accuracy: 73.090 epoch: 39 alexnet_train_loss: 0.488 alexnet_train_accuracy: 89.698 alexnet_validate_loss: 5.082 alexnet_validate_accuracy: 84.990 epoch: 40 alexnet_train_loss: 0.473 alexnet_train_accuracy: 86.815 alexnet_validate_loss: 0.636 alexnet_validate_accuracy: 83.720 epoch: 41 alexnet_train_loss: 0.417 alexnet_train_accuracy: 88.455 alexnet_validate_loss: 0.610 alexnet_validate_accuracy: 86.340 epoch: 42 alexnet_train_loss: 0.480 alexnet_train_accuracy: 86.640 alexnet_validate_loss: 0.767 alexnet_validate_accuracy: 79.750 epoch: 43 alexnet_train_loss: 0.414 alexnet_train_accuracy: 88.062 alexnet_validate_loss: 1.541 alexnet_validate_accuracy: 84.850 epoch: 44 alexnet_train_loss: 0.582 alexnet_train_accuracy: 87.033 alexnet_validate_loss: 0.626 alexnet_validate_accuracy: 82.700 epoch: 45 alexnet_train_loss: 0.734 alexnet_train_accuracy: 81.022 alexnet_validate_loss: 0.663 alexnet_validate_accuracy: 81.080 epoch: 46 alexnet_train_loss: 0.468 alexnet_train_accuracy: 87.082 alexnet_validate_loss: 0.882 alexnet_validate_accuracy: 72.010 epoch: 47 alexnet_train_loss: 0.507 alexnet_train_accuracy: 85.895 alexnet_validate_loss: 0.708 alexnet_validate_accuracy: 81.370 epoch: 48 alexnet_train_loss: 0.612 alexnet_train_accuracy: 82.643 alexnet_validate_loss: 0.641 alexnet_validate_accuracy: 80.320 epoch: 49 alexnet_train_loss: 0.674 alexnet_train_accuracy: 80.455 alexnet_validate_loss: 1.058 alexnet_validate_accuracy: 84.870 epoch: 50 alexnet_train_loss: 0.627 alexnet_train_accuracy: 82.395 alexnet_validate_loss: 0.913 alexnet_validate_accuracy: 73.170 epoch: 51 alexnet_train_loss: 0.686 alexnet_train_accuracy: 80.045 alexnet_validate_loss: 1.792 alexnet_validate_accuracy: 79.090 epoch: 52 alexnet_train_loss: 1.001 alexnet_train_accuracy: 76.137 alexnet_validate_loss: 0.776 alexnet_validate_accuracy: 77.090 epoch: 53 alexnet_train_loss: 0.715 alexnet_train_accuracy: 79.103 alexnet_validate_loss: 0.705 alexnet_validate_accuracy: 81.800 epoch: 54 alexnet_train_loss: 0.814 alexnet_train_accuracy: 79.415 alexnet_validate_loss: 0.665 alexnet_validate_accuracy: 80.450 epoch: 55 alexnet_train_loss: 0.749 alexnet_train_accuracy: 78.162 alexnet_validate_loss: 0.740 alexnet_validate_accuracy: 75.410 epoch: 56 alexnet_train_loss: 0.758 alexnet_train_accuracy: 78.040 alexnet_validate_loss: 0.735 alexnet_validate_accuracy: 82.220 epoch: 57 alexnet_train_loss: 0.822 alexnet_train_accuracy: 75.415 alexnet_validate_loss: 0.637 alexnet_validate_accuracy: 81.720 epoch: 58 alexnet_train_loss: 0.968 alexnet_train_accuracy: 70.317 alexnet_validate_loss: 0.664 alexnet_validate_accuracy: 79.950 epoch: 59 alexnet_train_loss: 1.215 alexnet_train_accuracy: 64.935 alexnet_validate_loss: 1.104 alexnet_validate_accuracy: 65.590 epoch: 60 alexnet_train_loss: 0.947 alexnet_train_accuracy: 71.247 alexnet_validate_loss: 3.346 alexnet_validate_accuracy: 69.950 epoch: 61 alexnet_train_loss: 0.882 alexnet_train_accuracy: 73.620 alexnet_validate_loss: 0.836 alexnet_validate_accuracy: 74.500 epoch: 62 alexnet_train_loss: 0.899 alexnet_train_accuracy: 72.747 alexnet_validate_loss: 2.586 alexnet_validate_accuracy: 63.560 epoch: 63 alexnet_train_loss: 1.309 alexnet_train_accuracy: 64.990 alexnet_validate_loss: 1.402 alexnet_validate_accuracy: 51.000 epoch: 64 alexnet_train_loss: 2.877 alexnet_train_accuracy: 55.222 alexnet_validate_loss: 0.876 alexnet_validate_accuracy: 72.660 epoch: 65 alexnet_train_loss: 1.080 alexnet_train_accuracy: 67.098 alexnet_validate_loss: 1.120 alexnet_validate_accuracy: 62.940 epoch: 66 alexnet_train_loss: 1.510 alexnet_train_accuracy: 53.173 alexnet_validate_loss: 1.054 alexnet_validate_accuracy: 64.190 epoch: 67 alexnet_train_loss: 1.904 alexnet_train_accuracy: 36.960 alexnet_validate_loss: 1.695 alexnet_validate_accuracy: 34.830 epoch: 68 alexnet_train_loss: 1.859 alexnet_train_accuracy: 35.350 alexnet_validate_loss: 1.632 alexnet_validate_accuracy: 42.080 epoch: 69 alexnet_train_loss: 2.115 alexnet_train_accuracy: 23.198 alexnet_validate_loss: 1.932 alexnet_validate_accuracy: 18.550 epoch: 70 alexnet_train_loss: 2.123 alexnet_train_accuracy: 20.605 alexnet_validate_loss: 1.955 alexnet_validate_accuracy: 18.630 epoch: 71 alexnet_train_loss: 2.148 alexnet_train_accuracy: 19.675 alexnet_validate_loss: 2.004 alexnet_validate_accuracy: 19.760 epoch: 72 alexnet_train_loss: 1.951 alexnet_train_accuracy: 20.243 alexnet_validate_loss: 1.812 alexnet_validate_accuracy: 25.590 epoch: 73 alexnet_train_loss: 2.050 alexnet_train_accuracy: 19.660 alexnet_validate_loss: 1.994 alexnet_validate_accuracy: 20.080 epoch: 74 alexnet_train_loss: 1.950 alexnet_train_accuracy: 19.325 alexnet_validate_loss: 1.802 alexnet_validate_accuracy: 23.780 epoch: 75 alexnet_train_loss: 1.971 alexnet_train_accuracy: 20.507 alexnet_validate_loss: 1.786 alexnet_validate_accuracy: 23.440 epoch: 76 alexnet_train_loss: 1.848 alexnet_train_accuracy: 22.903 alexnet_validate_loss: 1.746 alexnet_validate_accuracy: 28.540 epoch: 77 alexnet_train_loss: 1.879 alexnet_train_accuracy: 22.945 alexnet_validate_loss: 1.878 alexnet_validate_accuracy: 20.150 epoch: 78 alexnet_train_loss: 2.378 alexnet_train_accuracy: 23.980 alexnet_validate_loss: 1.720 alexnet_validate_accuracy: 31.060 epoch: 79 alexnet_train_loss: 2.087 alexnet_train_accuracy: 25.340 alexnet_validate_loss: 1.715 alexnet_validate_accuracy: 28.690 epoch: 80 alexnet_train_loss: 1.823 alexnet_train_accuracy: 25.550 alexnet_validate_loss: 1.717 alexnet_validate_accuracy: 29.760 epoch: 81 alexnet_train_loss: 1.823 alexnet_train_accuracy: 25.700 alexnet_validate_loss: 1.670 alexnet_validate_accuracy: 30.970 epoch: 82 alexnet_train_loss: 3.502 alexnet_train_accuracy: 27.030 alexnet_validate_loss: 8.780 alexnet_validate_accuracy: 26.950 epoch: 83 alexnet_train_loss: 1.810 alexnet_train_accuracy: 26.185 alexnet_validate_loss: 2.207 alexnet_validate_accuracy: 29.380 epoch: 84 alexnet_train_loss: 3.035 alexnet_train_accuracy: 28.243 alexnet_validate_loss: 1.696 alexnet_validate_accuracy: 25.830 epoch: 85 alexnet_train_loss: 1.731 alexnet_train_accuracy: 29.350 alexnet_validate_loss: 1.970 alexnet_validate_accuracy: 21.500 epoch: 86 alexnet_train_loss: 1.741 alexnet_train_accuracy: 28.772 alexnet_validate_loss: 1.680 alexnet_validate_accuracy: 28.270 epoch: 87 alexnet_train_loss: 1.708 alexnet_train_accuracy: 30.250 alexnet_validate_loss: 1.668 alexnet_validate_accuracy: 31.590 epoch: 88 alexnet_train_loss: 1.743 alexnet_train_accuracy: 30.765 alexnet_validate_loss: 1.650 alexnet_validate_accuracy: 32.550 epoch: 89 alexnet_train_loss: 1.888 alexnet_train_accuracy: 31.335 alexnet_validate_loss: 2.090 alexnet_validate_accuracy: 18.980 epoch: 90 alexnet_train_loss: 1.654 alexnet_train_accuracy: 32.897 alexnet_validate_loss: 1.596 alexnet_validate_accuracy: 34.870 epoch: 91 alexnet_train_loss: 1.676 alexnet_train_accuracy: 33.213 alexnet_validate_loss: 1.739 alexnet_validate_accuracy: 25.260 epoch: 92 alexnet_train_loss: 1.714 alexnet_train_accuracy: 31.305 alexnet_validate_loss: 1.637 alexnet_validate_accuracy: 30.600 epoch: 93 alexnet_train_loss: 2.160 alexnet_train_accuracy: 28.192 alexnet_validate_loss: 1.606 alexnet_validate_accuracy: 35.830 epoch: 94 alexnet_train_loss: 1.741 alexnet_train_accuracy: 29.843 alexnet_validate_loss: 1.731 alexnet_validate_accuracy: 29.230 epoch: 95 alexnet_train_loss: 1.984 alexnet_train_accuracy: 31.427 alexnet_validate_loss: 1.626 alexnet_validate_accuracy: 34.230 epoch: 96 alexnet_train_loss: 2.703 alexnet_train_accuracy: 30.188 alexnet_validate_loss: 1.634 alexnet_validate_accuracy: 34.240 epoch: 97 alexnet_train_loss: 1.707 alexnet_train_accuracy: 31.753 alexnet_validate_loss: 1.677 alexnet_validate_accuracy: 29.910 epoch: 98 alexnet_train_loss: 1.685 alexnet_train_accuracy: 32.290 alexnet_validate_loss: 1.681 alexnet_validate_accuracy: 30.040 epoch: 99 alexnet_train_loss: 1.785 alexnet_train_accuracy: 29.317 alexnet_validate_loss: 1.579 alexnet_validate_accuracy: 33.700 epoch: 100 alexnet_train_loss: 1.759 alexnet_train_accuracy: 30.528 alexnet_validate_loss: 1.675 alexnet_validate_accuracy: 31.120
Configuration 2: Frozen base convolution blocks
# Your changes here - also print trainable parameters
frozen_alexnet = models.alexnet(pretrained=True)
num_fc = frozen_alexnet.classifier[6].in_features
frozen_alexnet.classifier[6] = nn.Linear(in_features=num_fc, out_features=10)
frozen_alexnet = frozen_alexnet.to(device)
for param in frozen_alexnet.parameters():
param.requires_grad = False
for param in frozen_alexnet.classifier[6].parameters():
param.requires_grad = True
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(frozen_alexnet.parameters(), lr=0.001)
total_params = sum(p.numel() for p in frozen_alexnet.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
p.numel() for p in frozen_alexnet.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} trainable parameters.')
frozen_alexnet_best_loss = 1000
frozen_alexnet_train_loss, frozen_alexnet_validate_loss, frozen_alexnet_train_accuracy, frozen_alexnet_validate_accuracy = [], [], [], []
nepochs = 100
for epoch in range(nepochs):
frozen_alexnet_train_running_loss , frozen_alexnet_train_running_accuracy = train(CIFAR10trainloader, frozen_alexnet, criterion, optimizer)
frozen_alexnet_train_loss.append(frozen_alexnet_train_running_loss)
frozen_alexnet_train_accuracy.append(frozen_alexnet_train_running_accuracy)
frozen_alexnet_validate_running_loss , frozen_alexnet_validate_running_accuracy = validate(CIFAR10validateloader, frozen_alexnet, criterion, optimizer)
frozen_alexnet_validate_loss.append(frozen_alexnet_validate_running_loss)
frozen_alexnet_validate_accuracy.append(frozen_alexnet_validate_running_accuracy)
if frozen_alexnet_validate_running_loss < frozen_alexnet_best_loss:
frozen_alexnet_best_loss = frozen_alexnet_validate_running_loss
torch.save(frozen_alexnet.state_dict(), './frozen_alexnet.pt')
print(f"epoch: {epoch+1} alexnet_train_loss: {frozen_alexnet_train_loss[epoch] : .3f} frozen_alexnet_train_accuracy: {frozen_alexnet_train_accuracy[epoch] : .3f} frozen_alexnet_validate_loss: {frozen_alexnet_validate_loss[epoch] : .3f} frozen_alexnet_validate_accuracy: {frozen_alexnet_validate_accuracy[epoch] : .3f}")
57,044,810 total parameters. 40,970 trainable parameters. epoch: 1 alexnet_train_loss: 1.221 frozen_alexnet_train_accuracy: 65.088 frozen_alexnet_validate_loss: 1.013 frozen_alexnet_validate_accuracy: 70.520 epoch: 2 alexnet_train_loss: 1.211 frozen_alexnet_train_accuracy: 68.207 frozen_alexnet_validate_loss: 0.986 frozen_alexnet_validate_accuracy: 72.110 epoch: 3 alexnet_train_loss: 1.225 frozen_alexnet_train_accuracy: 69.043 frozen_alexnet_validate_loss: 1.026 frozen_alexnet_validate_accuracy: 72.350 epoch: 4 alexnet_train_loss: 1.227 frozen_alexnet_train_accuracy: 69.412 frozen_alexnet_validate_loss: 1.055 frozen_alexnet_validate_accuracy: 71.560 epoch: 5 alexnet_train_loss: 1.233 frozen_alexnet_train_accuracy: 69.603 frozen_alexnet_validate_loss: 1.062 frozen_alexnet_validate_accuracy: 72.770 epoch: 6 alexnet_train_loss: 1.228 frozen_alexnet_train_accuracy: 70.062 frozen_alexnet_validate_loss: 1.000 frozen_alexnet_validate_accuracy: 74.160 epoch: 7 alexnet_train_loss: 1.235 frozen_alexnet_train_accuracy: 70.420 frozen_alexnet_validate_loss: 1.150 frozen_alexnet_validate_accuracy: 72.020 epoch: 8 alexnet_train_loss: 1.228 frozen_alexnet_train_accuracy: 70.717 frozen_alexnet_validate_loss: 1.062 frozen_alexnet_validate_accuracy: 73.700 epoch: 9 alexnet_train_loss: 1.255 frozen_alexnet_train_accuracy: 70.018 frozen_alexnet_validate_loss: 1.055 frozen_alexnet_validate_accuracy: 72.700 epoch: 10 alexnet_train_loss: 1.236 frozen_alexnet_train_accuracy: 70.935 frozen_alexnet_validate_loss: 1.181 frozen_alexnet_validate_accuracy: 72.630 epoch: 11 alexnet_train_loss: 1.272 frozen_alexnet_train_accuracy: 70.298 frozen_alexnet_validate_loss: 1.038 frozen_alexnet_validate_accuracy: 73.880 epoch: 12 alexnet_train_loss: 1.255 frozen_alexnet_train_accuracy: 70.700 frozen_alexnet_validate_loss: 1.059 frozen_alexnet_validate_accuracy: 73.430 epoch: 13 alexnet_train_loss: 1.253 frozen_alexnet_train_accuracy: 70.692 frozen_alexnet_validate_loss: 0.997 frozen_alexnet_validate_accuracy: 74.390 epoch: 14 alexnet_train_loss: 1.252 frozen_alexnet_train_accuracy: 71.062 frozen_alexnet_validate_loss: 0.960 frozen_alexnet_validate_accuracy: 75.530 epoch: 15 alexnet_train_loss: 1.261 frozen_alexnet_train_accuracy: 70.690 frozen_alexnet_validate_loss: 1.064 frozen_alexnet_validate_accuracy: 73.250 epoch: 16 alexnet_train_loss: 1.251 frozen_alexnet_train_accuracy: 71.050 frozen_alexnet_validate_loss: 1.081 frozen_alexnet_validate_accuracy: 73.700 epoch: 17 alexnet_train_loss: 1.249 frozen_alexnet_train_accuracy: 71.442 frozen_alexnet_validate_loss: 1.227 frozen_alexnet_validate_accuracy: 71.940 epoch: 18 alexnet_train_loss: 1.283 frozen_alexnet_train_accuracy: 70.652 frozen_alexnet_validate_loss: 0.970 frozen_alexnet_validate_accuracy: 75.880 epoch: 19 alexnet_train_loss: 1.272 frozen_alexnet_train_accuracy: 70.777 frozen_alexnet_validate_loss: 1.234 frozen_alexnet_validate_accuracy: 71.300 epoch: 20 alexnet_train_loss: 1.252 frozen_alexnet_train_accuracy: 70.905 frozen_alexnet_validate_loss: 1.022 frozen_alexnet_validate_accuracy: 74.850 epoch: 21 alexnet_train_loss: 1.263 frozen_alexnet_train_accuracy: 71.228 frozen_alexnet_validate_loss: 0.971 frozen_alexnet_validate_accuracy: 75.340 epoch: 22 alexnet_train_loss: 1.269 frozen_alexnet_train_accuracy: 71.353 frozen_alexnet_validate_loss: 1.084 frozen_alexnet_validate_accuracy: 73.930 epoch: 23 alexnet_train_loss: 1.263 frozen_alexnet_train_accuracy: 71.058 frozen_alexnet_validate_loss: 1.166 frozen_alexnet_validate_accuracy: 72.880 epoch: 24 alexnet_train_loss: 1.274 frozen_alexnet_train_accuracy: 71.213 frozen_alexnet_validate_loss: 0.962 frozen_alexnet_validate_accuracy: 75.790 epoch: 25 alexnet_train_loss: 1.268 frozen_alexnet_train_accuracy: 71.150 frozen_alexnet_validate_loss: 0.943 frozen_alexnet_validate_accuracy: 76.760 epoch: 26 alexnet_train_loss: 1.278 frozen_alexnet_train_accuracy: 71.027 frozen_alexnet_validate_loss: 1.158 frozen_alexnet_validate_accuracy: 72.650 epoch: 27 alexnet_train_loss: 1.272 frozen_alexnet_train_accuracy: 71.143 frozen_alexnet_validate_loss: 1.019 frozen_alexnet_validate_accuracy: 75.980 epoch: 28 alexnet_train_loss: 1.264 frozen_alexnet_train_accuracy: 71.363 frozen_alexnet_validate_loss: 1.013 frozen_alexnet_validate_accuracy: 75.070 epoch: 29 alexnet_train_loss: 1.276 frozen_alexnet_train_accuracy: 71.065 frozen_alexnet_validate_loss: 0.982 frozen_alexnet_validate_accuracy: 75.500 epoch: 30 alexnet_train_loss: 1.278 frozen_alexnet_train_accuracy: 71.158 frozen_alexnet_validate_loss: 1.135 frozen_alexnet_validate_accuracy: 73.710 epoch: 31 alexnet_train_loss: 1.285 frozen_alexnet_train_accuracy: 71.522 frozen_alexnet_validate_loss: 1.156 frozen_alexnet_validate_accuracy: 72.690 epoch: 32 alexnet_train_loss: 1.276 frozen_alexnet_train_accuracy: 71.247 frozen_alexnet_validate_loss: 1.176 frozen_alexnet_validate_accuracy: 73.220 epoch: 33 alexnet_train_loss: 1.283 frozen_alexnet_train_accuracy: 71.092 frozen_alexnet_validate_loss: 1.096 frozen_alexnet_validate_accuracy: 74.230 epoch: 34 alexnet_train_loss: 1.274 frozen_alexnet_train_accuracy: 71.092 frozen_alexnet_validate_loss: 0.950 frozen_alexnet_validate_accuracy: 75.980 epoch: 35 alexnet_train_loss: 1.284 frozen_alexnet_train_accuracy: 71.173 frozen_alexnet_validate_loss: 1.071 frozen_alexnet_validate_accuracy: 74.110 epoch: 36 alexnet_train_loss: 1.276 frozen_alexnet_train_accuracy: 71.402 frozen_alexnet_validate_loss: 1.087 frozen_alexnet_validate_accuracy: 74.430 epoch: 37 alexnet_train_loss: 1.270 frozen_alexnet_train_accuracy: 71.510 frozen_alexnet_validate_loss: 1.233 frozen_alexnet_validate_accuracy: 72.780 epoch: 38 alexnet_train_loss: 1.280 frozen_alexnet_train_accuracy: 71.380 frozen_alexnet_validate_loss: 0.918 frozen_alexnet_validate_accuracy: 76.990 epoch: 39 alexnet_train_loss: 1.274 frozen_alexnet_train_accuracy: 71.555 frozen_alexnet_validate_loss: 1.076 frozen_alexnet_validate_accuracy: 74.100 epoch: 40 alexnet_train_loss: 1.266 frozen_alexnet_train_accuracy: 71.777 frozen_alexnet_validate_loss: 1.021 frozen_alexnet_validate_accuracy: 75.480 epoch: 41 alexnet_train_loss: 1.289 frozen_alexnet_train_accuracy: 70.963 frozen_alexnet_validate_loss: 0.961 frozen_alexnet_validate_accuracy: 76.430 epoch: 42 alexnet_train_loss: 1.286 frozen_alexnet_train_accuracy: 71.382 frozen_alexnet_validate_loss: 1.079 frozen_alexnet_validate_accuracy: 74.010 epoch: 43 alexnet_train_loss: 1.273 frozen_alexnet_train_accuracy: 71.393 frozen_alexnet_validate_loss: 1.062 frozen_alexnet_validate_accuracy: 75.150 epoch: 44 alexnet_train_loss: 1.283 frozen_alexnet_train_accuracy: 71.365 frozen_alexnet_validate_loss: 0.995 frozen_alexnet_validate_accuracy: 75.750 epoch: 45 alexnet_train_loss: 1.270 frozen_alexnet_train_accuracy: 71.560 frozen_alexnet_validate_loss: 1.065 frozen_alexnet_validate_accuracy: 73.920 epoch: 46 alexnet_train_loss: 1.287 frozen_alexnet_train_accuracy: 71.393 frozen_alexnet_validate_loss: 1.020 frozen_alexnet_validate_accuracy: 75.250 epoch: 47 alexnet_train_loss: 1.287 frozen_alexnet_train_accuracy: 71.507 frozen_alexnet_validate_loss: 1.338 frozen_alexnet_validate_accuracy: 71.470 epoch: 48 alexnet_train_loss: 1.291 frozen_alexnet_train_accuracy: 71.433 frozen_alexnet_validate_loss: 0.999 frozen_alexnet_validate_accuracy: 75.370 epoch: 49 alexnet_train_loss: 1.277 frozen_alexnet_train_accuracy: 71.548 frozen_alexnet_validate_loss: 1.393 frozen_alexnet_validate_accuracy: 70.220 epoch: 50 alexnet_train_loss: 1.282 frozen_alexnet_train_accuracy: 71.548 frozen_alexnet_validate_loss: 0.990 frozen_alexnet_validate_accuracy: 75.620 epoch: 51 alexnet_train_loss: 1.288 frozen_alexnet_train_accuracy: 71.418 frozen_alexnet_validate_loss: 0.969 frozen_alexnet_validate_accuracy: 75.800 epoch: 52 alexnet_train_loss: 1.279 frozen_alexnet_train_accuracy: 71.338 frozen_alexnet_validate_loss: 0.987 frozen_alexnet_validate_accuracy: 76.050 epoch: 53 alexnet_train_loss: 1.289 frozen_alexnet_train_accuracy: 71.442 frozen_alexnet_validate_loss: 1.145 frozen_alexnet_validate_accuracy: 73.350 epoch: 54 alexnet_train_loss: 1.277 frozen_alexnet_train_accuracy: 71.610 frozen_alexnet_validate_loss: 1.040 frozen_alexnet_validate_accuracy: 74.970 epoch: 55 alexnet_train_loss: 1.273 frozen_alexnet_train_accuracy: 71.495 frozen_alexnet_validate_loss: 1.073 frozen_alexnet_validate_accuracy: 73.670 epoch: 56 alexnet_train_loss: 1.296 frozen_alexnet_train_accuracy: 71.390 frozen_alexnet_validate_loss: 1.147 frozen_alexnet_validate_accuracy: 73.320 epoch: 57 alexnet_train_loss: 1.287 frozen_alexnet_train_accuracy: 71.395 frozen_alexnet_validate_loss: 1.131 frozen_alexnet_validate_accuracy: 72.870 epoch: 58 alexnet_train_loss: 1.285 frozen_alexnet_train_accuracy: 71.520 frozen_alexnet_validate_loss: 1.129 frozen_alexnet_validate_accuracy: 73.050 epoch: 59 alexnet_train_loss: 1.295 frozen_alexnet_train_accuracy: 71.257 frozen_alexnet_validate_loss: 1.079 frozen_alexnet_validate_accuracy: 73.970 epoch: 60 alexnet_train_loss: 1.296 frozen_alexnet_train_accuracy: 71.430 frozen_alexnet_validate_loss: 1.073 frozen_alexnet_validate_accuracy: 74.150 epoch: 61 alexnet_train_loss: 1.285 frozen_alexnet_train_accuracy: 71.625 frozen_alexnet_validate_loss: 0.973 frozen_alexnet_validate_accuracy: 75.700 epoch: 62 alexnet_train_loss: 1.275 frozen_alexnet_train_accuracy: 71.452 frozen_alexnet_validate_loss: 1.129 frozen_alexnet_validate_accuracy: 73.560 epoch: 63 alexnet_train_loss: 1.294 frozen_alexnet_train_accuracy: 71.283 frozen_alexnet_validate_loss: 1.064 frozen_alexnet_validate_accuracy: 74.510 epoch: 64 alexnet_train_loss: 1.302 frozen_alexnet_train_accuracy: 71.350 frozen_alexnet_validate_loss: 1.124 frozen_alexnet_validate_accuracy: 72.980 epoch: 65 alexnet_train_loss: 1.287 frozen_alexnet_train_accuracy: 71.228 frozen_alexnet_validate_loss: 1.002 frozen_alexnet_validate_accuracy: 75.530 epoch: 66 alexnet_train_loss: 1.289 frozen_alexnet_train_accuracy: 71.675 frozen_alexnet_validate_loss: 1.023 frozen_alexnet_validate_accuracy: 75.140 epoch: 67 alexnet_train_loss: 1.302 frozen_alexnet_train_accuracy: 71.180 frozen_alexnet_validate_loss: 0.964 frozen_alexnet_validate_accuracy: 76.340 epoch: 68 alexnet_train_loss: 1.290 frozen_alexnet_train_accuracy: 71.615 frozen_alexnet_validate_loss: 1.032 frozen_alexnet_validate_accuracy: 74.880 epoch: 69 alexnet_train_loss: 1.299 frozen_alexnet_train_accuracy: 71.375 frozen_alexnet_validate_loss: 1.012 frozen_alexnet_validate_accuracy: 75.560 epoch: 70 alexnet_train_loss: 1.294 frozen_alexnet_train_accuracy: 71.412 frozen_alexnet_validate_loss: 0.999 frozen_alexnet_validate_accuracy: 76.100 epoch: 71 alexnet_train_loss: 1.299 frozen_alexnet_train_accuracy: 71.085 frozen_alexnet_validate_loss: 1.129 frozen_alexnet_validate_accuracy: 73.600 epoch: 72 alexnet_train_loss: 1.290 frozen_alexnet_train_accuracy: 71.332 frozen_alexnet_validate_loss: 1.012 frozen_alexnet_validate_accuracy: 74.650 epoch: 73 alexnet_train_loss: 1.286 frozen_alexnet_train_accuracy: 71.215 frozen_alexnet_validate_loss: 1.049 frozen_alexnet_validate_accuracy: 75.230 epoch: 74 alexnet_train_loss: 1.284 frozen_alexnet_train_accuracy: 71.660 frozen_alexnet_validate_loss: 0.970 frozen_alexnet_validate_accuracy: 75.570 epoch: 75 alexnet_train_loss: 1.280 frozen_alexnet_train_accuracy: 71.757 frozen_alexnet_validate_loss: 1.088 frozen_alexnet_validate_accuracy: 74.000 epoch: 76 alexnet_train_loss: 1.294 frozen_alexnet_train_accuracy: 71.357 frozen_alexnet_validate_loss: 1.112 frozen_alexnet_validate_accuracy: 73.980 epoch: 77 alexnet_train_loss: 1.297 frozen_alexnet_train_accuracy: 71.327 frozen_alexnet_validate_loss: 1.116 frozen_alexnet_validate_accuracy: 73.670 epoch: 78 alexnet_train_loss: 1.304 frozen_alexnet_train_accuracy: 71.510 frozen_alexnet_validate_loss: 1.069 frozen_alexnet_validate_accuracy: 74.740 epoch: 79 alexnet_train_loss: 1.295 frozen_alexnet_train_accuracy: 71.543 frozen_alexnet_validate_loss: 1.163 frozen_alexnet_validate_accuracy: 73.060 epoch: 80 alexnet_train_loss: 1.282 frozen_alexnet_train_accuracy: 71.673 frozen_alexnet_validate_loss: 1.058 frozen_alexnet_validate_accuracy: 75.470 epoch: 81 alexnet_train_loss: 1.303 frozen_alexnet_train_accuracy: 71.605 frozen_alexnet_validate_loss: 1.134 frozen_alexnet_validate_accuracy: 73.480 epoch: 82 alexnet_train_loss: 1.285 frozen_alexnet_train_accuracy: 71.640 frozen_alexnet_validate_loss: 1.027 frozen_alexnet_validate_accuracy: 75.220 epoch: 83 alexnet_train_loss: 1.295 frozen_alexnet_train_accuracy: 71.637 frozen_alexnet_validate_loss: 1.229 frozen_alexnet_validate_accuracy: 72.680 epoch: 84 alexnet_train_loss: 1.280 frozen_alexnet_train_accuracy: 71.610 frozen_alexnet_validate_loss: 1.176 frozen_alexnet_validate_accuracy: 73.700 epoch: 85 alexnet_train_loss: 1.297 frozen_alexnet_train_accuracy: 71.115 frozen_alexnet_validate_loss: 1.343 frozen_alexnet_validate_accuracy: 71.130 epoch: 86 alexnet_train_loss: 1.292 frozen_alexnet_train_accuracy: 71.495 frozen_alexnet_validate_loss: 1.181 frozen_alexnet_validate_accuracy: 72.380 epoch: 87 alexnet_train_loss: 1.278 frozen_alexnet_train_accuracy: 71.567 frozen_alexnet_validate_loss: 1.094 frozen_alexnet_validate_accuracy: 73.420 epoch: 88 alexnet_train_loss: 1.293 frozen_alexnet_train_accuracy: 71.357 frozen_alexnet_validate_loss: 1.094 frozen_alexnet_validate_accuracy: 74.610 epoch: 89 alexnet_train_loss: 1.308 frozen_alexnet_train_accuracy: 71.317 frozen_alexnet_validate_loss: 1.055 frozen_alexnet_validate_accuracy: 74.620 epoch: 90 alexnet_train_loss: 1.303 frozen_alexnet_train_accuracy: 71.408 frozen_alexnet_validate_loss: 1.019 frozen_alexnet_validate_accuracy: 76.000 epoch: 91 alexnet_train_loss: 1.286 frozen_alexnet_train_accuracy: 71.310 frozen_alexnet_validate_loss: 1.136 frozen_alexnet_validate_accuracy: 73.380 epoch: 92 alexnet_train_loss: 1.282 frozen_alexnet_train_accuracy: 71.622 frozen_alexnet_validate_loss: 1.216 frozen_alexnet_validate_accuracy: 72.250 epoch: 93 alexnet_train_loss: 1.294 frozen_alexnet_train_accuracy: 71.545 frozen_alexnet_validate_loss: 1.079 frozen_alexnet_validate_accuracy: 74.420 epoch: 94 alexnet_train_loss: 1.306 frozen_alexnet_train_accuracy: 71.435 frozen_alexnet_validate_loss: 0.997 frozen_alexnet_validate_accuracy: 75.490 epoch: 95 alexnet_train_loss: 1.282 frozen_alexnet_train_accuracy: 71.442 frozen_alexnet_validate_loss: 1.135 frozen_alexnet_validate_accuracy: 73.330 epoch: 96 alexnet_train_loss: 1.306 frozen_alexnet_train_accuracy: 71.442 frozen_alexnet_validate_loss: 1.148 frozen_alexnet_validate_accuracy: 72.790 epoch: 97 alexnet_train_loss: 1.292 frozen_alexnet_train_accuracy: 71.410 frozen_alexnet_validate_loss: 0.954 frozen_alexnet_validate_accuracy: 76.570 epoch: 98 alexnet_train_loss: 1.301 frozen_alexnet_train_accuracy: 71.357 frozen_alexnet_validate_loss: 1.255 frozen_alexnet_validate_accuracy: 71.920 epoch: 99 alexnet_train_loss: 1.271 frozen_alexnet_train_accuracy: 72.003 frozen_alexnet_validate_loss: 1.179 frozen_alexnet_validate_accuracy: 72.660 epoch: 100 alexnet_train_loss: 1.304 frozen_alexnet_train_accuracy: 71.537 frozen_alexnet_validate_loss: 1.168 frozen_alexnet_validate_accuracy: 72.030
# Your graphs here and please provide comment in markdown in another cell
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(2,2,figsize=(15,20),sharex=True,sharey=False)
fig.suptitle('The loss and accuracy of train and validate sets with and without frozen')
axs[0][0].plot(x_axis,alexnet_train_loss,label='alexnet train_loss')
axs[0][0].plot(x_axis,frozen_alexnet_train_loss,label='frozen_alexnet train_loss')
axs[0][1].plot(x_axis,alexnet_validate_loss,label='alexnet validate_loss')
axs[0][1].plot(x_axis,frozen_alexnet_validate_loss,label='frozen_alexnet validate_loss')
axs[1][0].plot(x_axis,alexnet_train_accuracy,label='alexnet train_accuracy')
axs[1][0].plot(x_axis,frozen_alexnet_train_accuracy,label='frozen_alexnet train_accuracy')
axs[1][1].plot(x_axis,alexnet_validate_accuracy,label='alexnet validate_accuracy')
axs[1][1].plot(x_axis,frozen_alexnet_validate_accuracy,label='frozen_alexnet validate_accuracy')
axs[1][0].set_xlabel('epoch')
axs[1][1].set_xlabel('epoch')
axs[0][0].set_ylabel('loss')
axs[0][1].set_ylabel('loss')
axs[1][0].set_ylabel('percentage of accuracy')
axs[1][1].set_ylabel('percentage of accuracy')
axs[0][0].legend()
axs[0][1].legend()
axs[1][1].legend()
axs[1][0].legend()
plt.show()
If the convolution layer is not frozen and pre-trained weights are imported. The weights of the original AlexNet network are overwritten by the new data after the training epoch rises causing a severe bias, hence a steep drop in Accuracy and a steep rise and fluctuation in Loss.
We often need to compare our model with other state-of-the-art methods to understand how well it performs compared to existing architectures. Here you will thus compare your model design with AlexNet on the TinyImageNet30 dataset
Load AlexNet as you did above
Train AlexNet on TinyImageNet30 dataset until convergence. Make sure you use the same dataset
# Your code here!
alexnet_compare = models.alexnet(pretrained=True)
com_num_fc = alexnet_compare.classifier[6].in_features
alexnet_compare.classifier[6] = torch.nn.Linear(in_features=com_num_fc, out_features=30)
alexnet_compare = alexnet_compare.to(device)
print(alexnet_compare )
AlexNet(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): ReLU(inplace=True)
(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU(inplace=True)
(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(inplace=True)
(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(inplace=True)
(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
(classifier): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Linear(in_features=9216, out_features=4096, bias=True)
(2): ReLU(inplace=True)
(3): Dropout(p=0.5, inplace=False)
(4): Linear(in_features=4096, out_features=4096, bias=True)
(5): ReLU(inplace=True)
(6): Linear(in_features=4096, out_features=30, bias=True)
)
)
data_augmentation_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomRotation((-20,20)),
transforms.ColorJitter(hue=0.2, saturation=0.2, brightness=0.2),
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_set = MyDataset("train_set",transform=data_augmentation_transform)
length=len(train_set)
train_size,validate_size=int(0.8*length),int(0.2*length)
train_set,validate_set=torch.utils.data.random_split(train_set,[train_size,validate_size],generator=torch.Generator().manual_seed(0))
train_loader = DataLoader(
train_set,
batch_size = 64,
shuffle = True)
validate_loader = DataLoader(
validate_set,
batch_size = 64,
shuffle = True)
# Your code here!
alexnet_start = time.time()
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(alexnet_compare.parameters(), 0.0005)
alexnet_compare_best_loss = 1000.0
alexnet_compare_train_loss, alexnet_compare_train_accuracy= [], []
alexnet_compare_validate_loss, alexnet_compare_validate_accuracy= [], []
for epoch in range(nepochs):
alexnet_compare_train_running_loss , alexnet_compare_train_running_accuracy = train(train_loader, alexnet_compare, criterion, optimizer)
alexnet_compare_train_loss.append(alexnet_compare_train_running_loss)
alexnet_compare_train_accuracy.append(alexnet_compare_train_running_accuracy)
alexnet_compare_validate_running_loss , alexnet_compare_validate_running_accuracy = validate(validate_loader, alexnet_compare, criterion, optimizer)
alexnet_compare_validate_loss.append(alexnet_compare_validate_running_loss)
alexnet_compare_validate_accuracy.append(alexnet_compare_validate_running_accuracy)
if alexnet_compare_validate_running_loss < alexnet_compare_best_loss:
alexnet_compare_best_loss = alexnet_compare_validate_running_loss
torch.save(alexnet_compare.state_dict(), './alexnet_compare.pt')
print(f"epoch: {epoch+1} train_loss: {alexnet_compare_train_running_loss : .3f} train_accuracy: {alexnet_compare_train_running_accuracy : .3f} validate_loss: {alexnet_compare_validate_running_loss : .3f} validate_accuracy: {alexnet_compare_validate_running_accuracy : .3f}")
alexnet_end = time.time()
alexnet_running_time = alexnet_end - alexnet_start
epoch: 1 train_loss: 0.998 train_accuracy: 70.368 validate_loss: 1.710 validate_accuracy: 55.136 epoch: 2 train_loss: 0.993 train_accuracy: 70.602 validate_loss: 1.614 validate_accuracy: 57.679 epoch: 3 train_loss: 0.963 train_accuracy: 71.240 validate_loss: 1.662 validate_accuracy: 56.020 epoch: 4 train_loss: 0.999 train_accuracy: 70.648 validate_loss: 1.652 validate_accuracy: 56.565 epoch: 5 train_loss: 1.011 train_accuracy: 70.115 validate_loss: 1.563 validate_accuracy: 56.904 epoch: 6 train_loss: 0.963 train_accuracy: 71.320 validate_loss: 1.634 validate_accuracy: 56.880 epoch: 7 train_loss: 0.958 train_accuracy: 71.083 validate_loss: 1.595 validate_accuracy: 56.795 epoch: 8 train_loss: 0.950 train_accuracy: 71.893 validate_loss: 1.600 validate_accuracy: 56.831 epoch: 9 train_loss: 0.943 train_accuracy: 71.650 validate_loss: 1.632 validate_accuracy: 57.376 epoch: 10 train_loss: 0.946 train_accuracy: 71.946 validate_loss: 1.660 validate_accuracy: 55.729 epoch: 11 train_loss: 0.940 train_accuracy: 71.875 validate_loss: 1.634 validate_accuracy: 55.984 epoch: 12 train_loss: 0.922 train_accuracy: 72.482 validate_loss: 1.561 validate_accuracy: 57.219 epoch: 13 train_loss: 0.940 train_accuracy: 71.761 validate_loss: 1.674 validate_accuracy: 55.293 epoch: 14 train_loss: 0.961 train_accuracy: 71.320 validate_loss: 1.668 validate_accuracy: 55.608 epoch: 15 train_loss: 0.968 train_accuracy: 71.003 validate_loss: 1.598 validate_accuracy: 56.468 epoch: 16 train_loss: 0.919 train_accuracy: 71.912 validate_loss: 1.666 validate_accuracy: 55.608 epoch: 17 train_loss: 0.950 train_accuracy: 72.149 validate_loss: 1.644 validate_accuracy: 54.603 epoch: 18 train_loss: 0.970 train_accuracy: 71.071 validate_loss: 1.620 validate_accuracy: 56.953 epoch: 19 train_loss: 0.900 train_accuracy: 72.673 validate_loss: 1.652 validate_accuracy: 56.383 epoch: 20 train_loss: 0.937 train_accuracy: 71.712 validate_loss: 1.646 validate_accuracy: 56.068 epoch: 21 train_loss: 0.943 train_accuracy: 72.343 validate_loss: 1.673 validate_accuracy: 55.402 epoch: 22 train_loss: 0.944 train_accuracy: 71.739 validate_loss: 1.736 validate_accuracy: 56.044 epoch: 23 train_loss: 0.945 train_accuracy: 71.869 validate_loss: 1.681 validate_accuracy: 55.414 epoch: 24 train_loss: 0.911 train_accuracy: 72.698 validate_loss: 1.609 validate_accuracy: 57.607 epoch: 25 train_loss: 0.949 train_accuracy: 71.505 validate_loss: 1.745 validate_accuracy: 54.082 epoch: 26 train_loss: 0.931 train_accuracy: 71.992 validate_loss: 1.728 validate_accuracy: 54.566 epoch: 27 train_loss: 0.886 train_accuracy: 73.552 validate_loss: 1.671 validate_accuracy: 55.705 epoch: 28 train_loss: 0.939 train_accuracy: 72.467 validate_loss: 1.692 validate_accuracy: 56.214 epoch: 29 train_loss: 0.900 train_accuracy: 72.840 validate_loss: 1.671 validate_accuracy: 54.797 epoch: 30 train_loss: 0.908 train_accuracy: 73.018 validate_loss: 1.673 validate_accuracy: 55.366 epoch: 31 train_loss: 0.914 train_accuracy: 73.277 validate_loss: 1.584 validate_accuracy: 56.844 epoch: 32 train_loss: 0.923 train_accuracy: 72.020 validate_loss: 1.683 validate_accuracy: 55.584 epoch: 33 train_loss: 0.885 train_accuracy: 73.444 validate_loss: 1.708 validate_accuracy: 56.407 epoch: 34 train_loss: 0.870 train_accuracy: 73.632 validate_loss: 1.626 validate_accuracy: 56.202 epoch: 35 train_loss: 0.897 train_accuracy: 73.391 validate_loss: 1.676 validate_accuracy: 55.208 epoch: 36 train_loss: 0.915 train_accuracy: 73.074 validate_loss: 1.705 validate_accuracy: 53.949 epoch: 37 train_loss: 0.931 train_accuracy: 72.759 validate_loss: 1.755 validate_accuracy: 54.542 epoch: 38 train_loss: 0.893 train_accuracy: 73.422 validate_loss: 1.773 validate_accuracy: 56.456 epoch: 39 train_loss: 0.895 train_accuracy: 73.083 validate_loss: 1.697 validate_accuracy: 56.202 epoch: 40 train_loss: 0.893 train_accuracy: 72.926 validate_loss: 1.743 validate_accuracy: 56.056 epoch: 41 train_loss: 0.856 train_accuracy: 73.980 validate_loss: 1.738 validate_accuracy: 55.426 epoch: 42 train_loss: 0.881 train_accuracy: 73.968 validate_loss: 1.749 validate_accuracy: 55.233 epoch: 43 train_loss: 0.906 train_accuracy: 73.213 validate_loss: 1.688 validate_accuracy: 55.935 epoch: 44 train_loss: 0.906 train_accuracy: 73.323 validate_loss: 1.706 validate_accuracy: 56.177 epoch: 45 train_loss: 0.911 train_accuracy: 73.240 validate_loss: 1.714 validate_accuracy: 55.402 epoch: 46 train_loss: 0.921 train_accuracy: 72.673 validate_loss: 1.751 validate_accuracy: 53.634 epoch: 47 train_loss: 0.854 train_accuracy: 74.353 validate_loss: 1.711 validate_accuracy: 57.025 epoch: 48 train_loss: 0.919 train_accuracy: 73.077 validate_loss: 1.722 validate_accuracy: 54.966 epoch: 49 train_loss: 0.897 train_accuracy: 73.345 validate_loss: 1.759 validate_accuracy: 55.390 epoch: 50 train_loss: 0.887 train_accuracy: 73.780 validate_loss: 1.702 validate_accuracy: 56.444 epoch: 51 train_loss: 0.864 train_accuracy: 74.270 validate_loss: 1.694 validate_accuracy: 56.940 epoch: 52 train_loss: 0.863 train_accuracy: 74.458 validate_loss: 1.749 validate_accuracy: 54.203 epoch: 53 train_loss: 0.879 train_accuracy: 74.060 validate_loss: 1.749 validate_accuracy: 53.973 epoch: 54 train_loss: 0.866 train_accuracy: 73.937 validate_loss: 1.733 validate_accuracy: 54.336 epoch: 55 train_loss: 0.897 train_accuracy: 73.555 validate_loss: 1.700 validate_accuracy: 55.051 epoch: 56 train_loss: 0.920 train_accuracy: 72.602 validate_loss: 1.690 validate_accuracy: 55.463 epoch: 57 train_loss: 0.850 train_accuracy: 75.099 validate_loss: 1.699 validate_accuracy: 56.953 epoch: 58 train_loss: 0.867 train_accuracy: 74.011 validate_loss: 1.647 validate_accuracy: 55.959 epoch: 59 train_loss: 0.839 train_accuracy: 75.398 validate_loss: 1.748 validate_accuracy: 56.989 epoch: 60 train_loss: 0.854 train_accuracy: 74.726 validate_loss: 1.865 validate_accuracy: 55.632 epoch: 61 train_loss: 0.899 train_accuracy: 74.005 validate_loss: 1.741 validate_accuracy: 56.008 epoch: 62 train_loss: 0.846 train_accuracy: 74.661 validate_loss: 1.751 validate_accuracy: 57.437 epoch: 63 train_loss: 0.884 train_accuracy: 74.140 validate_loss: 1.741 validate_accuracy: 55.402 epoch: 64 train_loss: 0.822 train_accuracy: 75.515 validate_loss: 1.789 validate_accuracy: 55.838 epoch: 65 train_loss: 0.872 train_accuracy: 74.670 validate_loss: 1.762 validate_accuracy: 54.494 epoch: 66 train_loss: 0.908 train_accuracy: 73.576 validate_loss: 1.690 validate_accuracy: 55.620 epoch: 67 train_loss: 0.861 train_accuracy: 74.686 validate_loss: 1.708 validate_accuracy: 56.189 epoch: 68 train_loss: 0.861 train_accuracy: 74.593 validate_loss: 1.739 validate_accuracy: 55.826 epoch: 69 train_loss: 0.884 train_accuracy: 74.057 validate_loss: 1.721 validate_accuracy: 54.845 epoch: 70 train_loss: 0.865 train_accuracy: 74.775 validate_loss: 1.813 validate_accuracy: 54.760 epoch: 71 train_loss: 0.896 train_accuracy: 73.755 validate_loss: 1.762 validate_accuracy: 53.755 epoch: 72 train_loss: 0.832 train_accuracy: 74.951 validate_loss: 1.748 validate_accuracy: 55.196 epoch: 73 train_loss: 0.842 train_accuracy: 75.034 validate_loss: 1.670 validate_accuracy: 55.947 epoch: 74 train_loss: 0.889 train_accuracy: 73.986 validate_loss: 1.761 validate_accuracy: 55.911 epoch: 75 train_loss: 0.861 train_accuracy: 74.596 validate_loss: 1.804 validate_accuracy: 54.881 epoch: 76 train_loss: 0.862 train_accuracy: 75.148 validate_loss: 1.761 validate_accuracy: 54.142 epoch: 77 train_loss: 0.833 train_accuracy: 75.693 validate_loss: 1.736 validate_accuracy: 55.329 epoch: 78 train_loss: 0.887 train_accuracy: 74.063 validate_loss: 1.694 validate_accuracy: 55.305 epoch: 79 train_loss: 0.822 train_accuracy: 75.324 validate_loss: 1.733 validate_accuracy: 56.274 epoch: 80 train_loss: 0.828 train_accuracy: 75.884 validate_loss: 1.680 validate_accuracy: 56.177 epoch: 81 train_loss: 0.839 train_accuracy: 75.173 validate_loss: 1.690 validate_accuracy: 55.947 epoch: 82 train_loss: 0.860 train_accuracy: 74.609 validate_loss: 1.784 validate_accuracy: 55.463 epoch: 83 train_loss: 0.846 train_accuracy: 75.247 validate_loss: 1.783 validate_accuracy: 57.122 epoch: 84 train_loss: 0.852 train_accuracy: 75.083 validate_loss: 1.808 validate_accuracy: 55.669 epoch: 85 train_loss: 0.845 train_accuracy: 75.339 validate_loss: 1.755 validate_accuracy: 56.347 epoch: 86 train_loss: 0.832 train_accuracy: 75.592 validate_loss: 1.766 validate_accuracy: 54.784 epoch: 87 train_loss: 0.850 train_accuracy: 75.592 validate_loss: 1.813 validate_accuracy: 53.791 epoch: 88 train_loss: 0.872 train_accuracy: 74.528 validate_loss: 1.943 validate_accuracy: 53.719 epoch: 89 train_loss: 0.848 train_accuracy: 75.151 validate_loss: 1.756 validate_accuracy: 56.904 epoch: 90 train_loss: 0.914 train_accuracy: 73.567 validate_loss: 1.770 validate_accuracy: 54.700 epoch: 91 train_loss: 0.877 train_accuracy: 74.491 validate_loss: 1.850 validate_accuracy: 53.852 epoch: 92 train_loss: 0.883 train_accuracy: 74.935 validate_loss: 1.777 validate_accuracy: 55.814 epoch: 93 train_loss: 0.846 train_accuracy: 75.542 validate_loss: 1.776 validate_accuracy: 56.202 epoch: 94 train_loss: 0.866 train_accuracy: 74.732 validate_loss: 1.743 validate_accuracy: 56.141 epoch: 95 train_loss: 0.823 train_accuracy: 75.906 validate_loss: 1.854 validate_accuracy: 55.596 epoch: 96 train_loss: 0.842 train_accuracy: 75.413 validate_loss: 1.725 validate_accuracy: 55.281 epoch: 97 train_loss: 0.839 train_accuracy: 75.881 validate_loss: 1.809 validate_accuracy: 55.632 epoch: 98 train_loss: 0.832 train_accuracy: 75.977 validate_loss: 1.862 validate_accuracy: 54.227 epoch: 99 train_loss: 0.835 train_accuracy: 75.675 validate_loss: 1.756 validate_accuracy: 55.245 epoch: 100 train_loss: 0.844 train_accuracy: 75.370 validate_loss: 1.864 validate_accuracy: 54.554
class CNN_compared(nn.Module):
def __init__(self):
super(CNN_compared,self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.flc1 = nn.Linear(64*28*28,1024)
self.dropout = nn.Dropout(p=0.3)
self.flc2 = nn.Linear(1024,30)
def forward(self,x):
x = self.maxpool1(nn.functional.relu(self.conv1(x)))
x = self.maxpool2(nn.functional.relu(self.conv2(x)))
x = self.maxpool3(nn.functional.relu(self.conv3(x)))
x = x.view(-1,64*28*28)
x = self.dropout(x)
x = nn.functional.relu(self.flc1(x))
x = self.flc2(x)
return x
CNN_compare_model = CNN_compared()
CNN_compare_model = CNN_compare_model.to(device)
# Your code here!
cnn_start = time.time()
nepochs = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(CNN_compare_model.parameters(), 0.001)
CNN_compare_best_loss = 1000.0
CNN_compare_train_loss, CNN_compare_train_accuracy= [], []
CNN_compare_validate_loss, CNN_compare_validate_accuracy= [], []
for epoch in range(nepochs):
CNN_compare_train_running_loss , CNN_compare_train_running_accuracy = train(train_loader, CNN_compare_model, criterion, optimizer)
CNN_compare_train_loss.append(CNN_compare_train_running_loss)
CNN_compare_train_accuracy.append(CNN_compare_train_running_accuracy)
CNN_compare_validate_running_loss , CNN_compare_validate_running_accuracy = validate(validate_loader, CNN_compare_model, criterion, optimizer)
CNN_compare_validate_loss.append(CNN_compare_validate_running_loss)
CNN_compare_validate_accuracy.append(CNN_compare_validate_running_accuracy)
if CNN_compare_validate_running_loss < CNN_compare_best_loss:
CNN_compare_best_loss = CNN_compare_validate_running_loss
torch.save(CNN_compare_model.state_dict(), './cnn_compare.pt')
print(f"epoch: {epoch+1} train_loss: {CNN_compare_train_running_loss : .3f} train_accuracy: {CNN_compare_train_running_accuracy : .3f} validate_loss: {CNN_compare_validate_running_loss : .3f} validate_accuracy: {CNN_compare_validate_running_accuracy : .3f}")
cnn_end = time.time()
cnn_running_time = cnn_end - cnn_start
epoch: 1 train_loss: 3.293 train_accuracy: 9.159 validate_loss: 3.021 validate_accuracy: 13.118 epoch: 2 train_loss: 2.967 train_accuracy: 15.178 validate_loss: 2.878 validate_accuracy: 17.696 epoch: 3 train_loss: 2.739 train_accuracy: 20.707 validate_loss: 2.666 validate_accuracy: 23.462 epoch: 4 train_loss: 2.599 train_accuracy: 24.797 validate_loss: 2.595 validate_accuracy: 25.230 epoch: 5 train_loss: 2.511 train_accuracy: 27.197 validate_loss: 2.544 validate_accuracy: 26.441 epoch: 6 train_loss: 2.420 train_accuracy: 29.370 validate_loss: 2.435 validate_accuracy: 30.245 epoch: 7 train_loss: 2.354 train_accuracy: 31.161 validate_loss: 2.444 validate_accuracy: 30.015 epoch: 8 train_loss: 2.284 train_accuracy: 33.343 validate_loss: 2.338 validate_accuracy: 32.401 epoch: 9 train_loss: 2.203 train_accuracy: 35.392 validate_loss: 2.336 validate_accuracy: 33.321 epoch: 10 train_loss: 2.126 train_accuracy: 37.133 validate_loss: 2.305 validate_accuracy: 34.096 epoch: 11 train_loss: 2.056 train_accuracy: 39.614 validate_loss: 2.316 validate_accuracy: 33.188 epoch: 12 train_loss: 2.004 train_accuracy: 41.306 validate_loss: 2.270 validate_accuracy: 34.133 epoch: 13 train_loss: 1.938 train_accuracy: 42.490 validate_loss: 2.270 validate_accuracy: 34.823 epoch: 14 train_loss: 1.887 train_accuracy: 44.181 validate_loss: 2.208 validate_accuracy: 36.955 epoch: 15 train_loss: 1.823 train_accuracy: 45.750 validate_loss: 2.248 validate_accuracy: 36.313 epoch: 16 train_loss: 1.773 train_accuracy: 46.641 validate_loss: 2.263 validate_accuracy: 37.621 epoch: 17 train_loss: 1.735 train_accuracy: 48.724 validate_loss: 2.278 validate_accuracy: 36.470 epoch: 18 train_loss: 1.671 train_accuracy: 50.120 validate_loss: 2.203 validate_accuracy: 37.984 epoch: 19 train_loss: 1.615 train_accuracy: 51.624 validate_loss: 2.309 validate_accuracy: 37.052 epoch: 20 train_loss: 1.557 train_accuracy: 53.985 validate_loss: 2.216 validate_accuracy: 38.748 epoch: 21 train_loss: 1.515 train_accuracy: 54.367 validate_loss: 2.300 validate_accuracy: 37.645 epoch: 22 train_loss: 1.469 train_accuracy: 55.711 validate_loss: 2.257 validate_accuracy: 39.087 epoch: 23 train_loss: 1.414 train_accuracy: 57.128 validate_loss: 2.335 validate_accuracy: 36.906 epoch: 24 train_loss: 1.368 train_accuracy: 58.728 validate_loss: 2.313 validate_accuracy: 38.033 epoch: 25 train_loss: 1.324 train_accuracy: 59.970 validate_loss: 2.422 validate_accuracy: 38.154 epoch: 26 train_loss: 1.285 train_accuracy: 60.793 validate_loss: 2.357 validate_accuracy: 38.299 epoch: 27 train_loss: 1.245 train_accuracy: 62.432 validate_loss: 2.369 validate_accuracy: 38.578 epoch: 28 train_loss: 1.222 train_accuracy: 62.913 validate_loss: 2.409 validate_accuracy: 37.330 epoch: 29 train_loss: 1.147 train_accuracy: 64.925 validate_loss: 2.483 validate_accuracy: 38.009 epoch: 30 train_loss: 1.137 train_accuracy: 65.243 validate_loss: 2.422 validate_accuracy: 40.007 epoch: 31 train_loss: 1.112 train_accuracy: 66.355 validate_loss: 2.556 validate_accuracy: 36.701 epoch: 32 train_loss: 1.075 train_accuracy: 67.172 validate_loss: 2.462 validate_accuracy: 37.028 epoch: 33 train_loss: 1.050 train_accuracy: 68.371 validate_loss: 2.510 validate_accuracy: 37.391 epoch: 34 train_loss: 0.990 train_accuracy: 69.576 validate_loss: 2.611 validate_accuracy: 38.178 epoch: 35 train_loss: 0.981 train_accuracy: 70.081 validate_loss: 2.607 validate_accuracy: 38.505 epoch: 36 train_loss: 0.933 train_accuracy: 71.182 validate_loss: 2.665 validate_accuracy: 38.832 epoch: 37 train_loss: 0.915 train_accuracy: 72.165 validate_loss: 2.571 validate_accuracy: 38.639 epoch: 38 train_loss: 0.907 train_accuracy: 72.809 validate_loss: 2.680 validate_accuracy: 38.675 epoch: 39 train_loss: 0.877 train_accuracy: 72.602 validate_loss: 2.735 validate_accuracy: 36.967 epoch: 40 train_loss: 0.850 train_accuracy: 74.082 validate_loss: 2.639 validate_accuracy: 37.924 epoch: 41 train_loss: 0.826 train_accuracy: 74.328 validate_loss: 2.842 validate_accuracy: 37.609 epoch: 42 train_loss: 0.820 train_accuracy: 74.784 validate_loss: 2.658 validate_accuracy: 39.050 epoch: 43 train_loss: 0.782 train_accuracy: 76.193 validate_loss: 2.801 validate_accuracy: 38.651 epoch: 44 train_loss: 0.766 train_accuracy: 76.498 validate_loss: 2.822 validate_accuracy: 38.869 epoch: 45 train_loss: 0.764 train_accuracy: 76.436 validate_loss: 2.869 validate_accuracy: 37.149 epoch: 46 train_loss: 0.728 train_accuracy: 77.632 validate_loss: 2.796 validate_accuracy: 37.827 epoch: 47 train_loss: 0.726 train_accuracy: 77.962 validate_loss: 2.849 validate_accuracy: 38.978 epoch: 48 train_loss: 0.717 train_accuracy: 77.835 validate_loss: 2.976 validate_accuracy: 37.209 epoch: 49 train_loss: 0.696 train_accuracy: 78.359 validate_loss: 2.927 validate_accuracy: 39.293 epoch: 50 train_loss: 0.654 train_accuracy: 79.946 validate_loss: 3.114 validate_accuracy: 37.088 epoch: 51 train_loss: 0.655 train_accuracy: 79.617 validate_loss: 2.909 validate_accuracy: 37.536 epoch: 52 train_loss: 0.660 train_accuracy: 79.706 validate_loss: 2.931 validate_accuracy: 38.857 epoch: 53 train_loss: 0.620 train_accuracy: 80.464 validate_loss: 3.119 validate_accuracy: 37.573 epoch: 54 train_loss: 0.597 train_accuracy: 81.561 validate_loss: 3.084 validate_accuracy: 38.094 epoch: 55 train_loss: 0.589 train_accuracy: 81.950 validate_loss: 3.116 validate_accuracy: 38.699 epoch: 56 train_loss: 0.612 train_accuracy: 80.886 validate_loss: 3.231 validate_accuracy: 37.754 epoch: 57 train_loss: 0.573 train_accuracy: 82.215 validate_loss: 3.139 validate_accuracy: 38.190 epoch: 58 train_loss: 0.577 train_accuracy: 81.740 validate_loss: 3.138 validate_accuracy: 37.875 epoch: 59 train_loss: 0.574 train_accuracy: 82.261 validate_loss: 3.119 validate_accuracy: 38.142 epoch: 60 train_loss: 0.556 train_accuracy: 82.640 validate_loss: 3.278 validate_accuracy: 37.779 epoch: 61 train_loss: 0.544 train_accuracy: 82.683 validate_loss: 3.191 validate_accuracy: 38.614 epoch: 62 train_loss: 0.534 train_accuracy: 83.802 validate_loss: 3.258 validate_accuracy: 37.548 epoch: 63 train_loss: 0.519 train_accuracy: 83.645 validate_loss: 3.264 validate_accuracy: 38.154 epoch: 64 train_loss: 0.522 train_accuracy: 83.460 validate_loss: 3.243 validate_accuracy: 38.069 epoch: 65 train_loss: 0.489 train_accuracy: 84.751 validate_loss: 3.396 validate_accuracy: 37.064 epoch: 66 train_loss: 0.500 train_accuracy: 84.600 validate_loss: 3.404 validate_accuracy: 38.953 epoch: 67 train_loss: 0.495 train_accuracy: 84.477 validate_loss: 3.234 validate_accuracy: 39.268 epoch: 68 train_loss: 0.502 train_accuracy: 84.532 validate_loss: 3.371 validate_accuracy: 38.009 epoch: 69 train_loss: 0.468 train_accuracy: 85.660 validate_loss: 3.432 validate_accuracy: 36.652 epoch: 70 train_loss: 0.476 train_accuracy: 85.306 validate_loss: 3.411 validate_accuracy: 37.258 epoch: 71 train_loss: 0.493 train_accuracy: 84.850 validate_loss: 3.443 validate_accuracy: 36.761 epoch: 72 train_loss: 0.467 train_accuracy: 85.370 validate_loss: 3.412 validate_accuracy: 38.227 epoch: 73 train_loss: 0.467 train_accuracy: 84.998 validate_loss: 3.360 validate_accuracy: 37.948 epoch: 74 train_loss: 0.462 train_accuracy: 85.404 validate_loss: 3.426 validate_accuracy: 38.057 epoch: 75 train_loss: 0.429 train_accuracy: 86.418 validate_loss: 3.613 validate_accuracy: 38.178 epoch: 76 train_loss: 0.439 train_accuracy: 86.138 validate_loss: 3.534 validate_accuracy: 36.834 epoch: 77 train_loss: 0.445 train_accuracy: 86.378 validate_loss: 3.481 validate_accuracy: 38.227 epoch: 78 train_loss: 0.444 train_accuracy: 86.061 validate_loss: 3.563 validate_accuracy: 37.016 epoch: 79 train_loss: 0.400 train_accuracy: 87.016 validate_loss: 3.649 validate_accuracy: 36.240 epoch: 80 train_loss: 0.413 train_accuracy: 87.186 validate_loss: 3.695 validate_accuracy: 37.452 epoch: 81 train_loss: 0.415 train_accuracy: 87.099 validate_loss: 3.598 validate_accuracy: 38.081 epoch: 82 train_loss: 0.407 train_accuracy: 87.737 validate_loss: 3.718 validate_accuracy: 37.536 epoch: 83 train_loss: 0.392 train_accuracy: 87.913 validate_loss: 3.673 validate_accuracy: 37.670 epoch: 84 train_loss: 0.392 train_accuracy: 87.861 validate_loss: 3.589 validate_accuracy: 39.147 epoch: 85 train_loss: 0.411 train_accuracy: 87.272 validate_loss: 3.622 validate_accuracy: 37.597 epoch: 86 train_loss: 0.392 train_accuracy: 87.777 validate_loss: 3.716 validate_accuracy: 37.173 epoch: 87 train_loss: 0.391 train_accuracy: 87.706 validate_loss: 3.713 validate_accuracy: 38.542 epoch: 88 train_loss: 0.366 train_accuracy: 88.240 validate_loss: 3.733 validate_accuracy: 38.469 epoch: 89 train_loss: 0.380 train_accuracy: 87.972 validate_loss: 3.832 validate_accuracy: 35.744 epoch: 90 train_loss: 0.387 train_accuracy: 88.332 validate_loss: 3.761 validate_accuracy: 38.348 epoch: 91 train_loss: 0.382 train_accuracy: 88.107 validate_loss: 3.777 validate_accuracy: 38.542 epoch: 92 train_loss: 0.352 train_accuracy: 88.671 validate_loss: 3.834 validate_accuracy: 37.270 epoch: 93 train_loss: 0.359 train_accuracy: 88.911 validate_loss: 3.713 validate_accuracy: 37.972 epoch: 94 train_loss: 0.375 train_accuracy: 88.613 validate_loss: 3.731 validate_accuracy: 38.106 epoch: 95 train_loss: 0.339 train_accuracy: 89.294 validate_loss: 3.757 validate_accuracy: 38.578 epoch: 96 train_loss: 0.364 train_accuracy: 88.650 validate_loss: 3.875 validate_accuracy: 37.718 epoch: 97 train_loss: 0.345 train_accuracy: 88.995 validate_loss: 3.580 validate_accuracy: 39.014 epoch: 98 train_loss: 0.334 train_accuracy: 89.586 validate_loss: 3.875 validate_accuracy: 37.561 epoch: 99 train_loss: 0.346 train_accuracy: 88.985 validate_loss: 3.789 validate_accuracy: 38.493 epoch: 100 train_loss: 0.338 train_accuracy: 89.534 validate_loss: 3.988 validate_accuracy: 37.779
Loss graph, top1 accuracy, confusion matrix and execution time for your model (say, mymodel and AlexNet)
# Your code here!
# Loss and accuracy graph
x_axis = np.arange(1,nepochs+1,1,int)
fig,axs=plt.subplots(1,2,figsize=(15,20),sharex=False,sharey=False)
fig.suptitle('Compare CNN and AlexNet in validate sets ')
axs[0].plot(x_axis,alexnet_compare_validate_loss,label='AlexNet validate_loss')
axs[0].plot(x_axis,CNN_compare_validate_loss,label='CNN validate_loss')
axs[1].plot(x_axis,alexnet_compare_validate_accuracy,label='AlexNet validate_accuracy')
axs[1].plot(x_axis, CNN_compare_validate_accuracy,label='CNN validate_accuracy')
axs[0].set_xlabel('epoch')
axs[1].set_xlabel('epoch')
axs[0].set_ylabel('loss')
axs[1].set_ylabel('percentage of accuracy')
axs[0].legend()
axs[1].legend()
plt.show()
# confusion matrix
num_class = len(classes)
nclasses = len(classes)
CNN_compare_model.load_state_dict(torch.load( './cnn_compare.pt'))
alexnet_compare.load_state_dict(torch.load( './alexnet_compare.pt'))
cnfm_cnn = np.zeros((nclasses,nclasses),dtype=int)
cnfm_alexnet = np.zeros((nclasses,nclasses),dtype=int)
with torch.no_grad():
for data in validate_loader:
images, labels = data
images = images.to(device)
labels = labels.to(device)
CNN_outputs = CNN_compare_model(images)
_, CNN_predicted = torch.max(CNN_outputs, 1)
CNN_score_tmp = CNN_outputs
for i in range(labels.size(0)):
cnfm_cnn[labels[i].item(),CNN_predicted[i].item()] += 1
alexnet_outputs = alexnet_compare(images)
_, alexnet_predicted = torch.max(alexnet_outputs, 1)
alexnet_score_tmp = alexnet_outputs
for i in range(labels.size(0)):
cnfm_alexnet[labels[i].item(),alexnet_predicted[i].item()] += 1
print("CNN Model Confusion Matrix")
print(cnfm_cnn)
# show confusion matrix as a grey-level image
plt.imshow(cnfm_cnn, cmap='gray')
CNN Model Confusion Matrix [[28 0 0 5 1 1 2 4 7 3 2 1 3 2 0 0 1 0 7 1 5 0 1 12 0 0 2 0 2 0] [ 1 22 0 0 1 0 0 1 0 2 1 11 0 0 4 10 0 1 0 5 0 0 2 1 0 7 0 0 1 1] [ 0 1 25 0 6 0 1 3 1 0 29 1 0 2 1 2 2 0 0 6 2 0 1 0 1 1 2 1 2 1] [ 4 0 0 41 0 0 0 2 9 0 0 0 10 0 0 0 2 2 4 0 2 1 0 10 0 0 3 1 0 1] [ 0 1 2 0 84 0 0 0 0 0 2 0 0 2 0 0 0 1 0 2 1 0 1 0 1 0 3 1 3 0] [ 0 6 6 0 0 18 3 4 1 3 0 6 0 1 6 3 2 4 0 4 1 1 2 2 1 3 2 2 5 3] [ 4 0 2 2 0 1 32 1 2 1 1 2 2 2 6 0 1 0 5 1 5 2 3 2 1 0 5 4 5 0] [ 1 2 1 1 0 2 0 21 1 2 1 0 2 1 7 8 3 1 4 3 3 0 2 7 0 2 5 2 9 0] [ 9 0 1 13 0 0 0 3 20 1 3 0 8 1 0 1 1 0 10 0 3 2 2 2 0 0 1 3 4 1] [ 0 2 1 0 1 1 0 2 0 46 2 2 0 2 2 9 4 0 0 1 0 0 0 0 2 2 1 2 4 1] [ 3 1 7 3 3 1 0 1 2 0 49 1 2 10 0 2 1 2 1 3 1 0 0 0 0 0 0 0 1 3] [ 0 9 0 0 1 1 0 2 0 1 0 48 1 0 1 5 1 1 0 7 4 1 0 0 1 2 1 0 3 2] [ 2 1 1 5 0 0 1 4 1 0 5 0 37 1 0 0 2 1 3 0 3 2 3 9 0 0 2 2 2 1] [ 0 2 10 3 6 0 1 2 0 0 14 2 5 22 0 3 1 1 4 6 2 0 1 1 4 0 1 1 1 1] [ 0 2 2 0 0 4 5 4 1 1 4 2 1 0 35 0 8 0 1 1 1 0 6 1 1 2 0 3 2 3] [ 1 0 4 1 0 1 0 2 1 6 2 1 1 3 4 28 2 1 2 3 3 1 0 4 5 9 2 3 3 1] [ 0 0 0 3 0 0 0 3 0 5 1 1 2 0 6 4 24 1 3 2 0 0 2 0 0 0 2 6 5 2] [ 0 1 1 0 0 2 0 1 0 1 1 1 1 0 2 2 0 66 0 0 1 0 0 1 0 0 0 1 1 0] [ 5 0 3 0 0 0 2 3 1 0 0 0 4 0 0 0 0 0 59 0 1 2 0 3 0 1 6 1 1 1] [ 1 6 4 0 3 1 2 1 0 3 7 4 4 3 4 6 0 0 0 37 0 0 1 1 0 1 3 1 1 4] [14 2 3 2 1 1 1 2 1 0 0 0 5 1 2 3 0 0 5 0 31 5 0 5 4 0 3 2 2 0] [ 2 1 2 2 2 1 0 1 0 3 4 7 1 2 3 2 0 4 3 1 4 17 0 1 3 1 9 1 5 0] [ 0 0 0 5 0 1 1 4 4 0 1 0 7 2 5 0 3 1 1 2 0 0 32 0 0 0 3 4 3 0] [ 6 0 0 6 1 0 1 0 6 1 4 0 9 5 1 5 1 0 7 0 10 1 1 12 0 1 6 4 3 2] [ 0 6 2 0 0 1 0 0 1 1 2 0 0 2 0 8 0 0 0 0 1 1 0 0 51 1 2 1 7 0] [ 0 4 0 1 0 1 1 3 1 3 1 2 0 4 2 16 0 0 0 2 3 0 0 0 10 23 1 2 4 2] [ 5 0 2 3 2 1 0 2 0 3 0 0 1 1 1 3 0 0 6 2 1 3 1 2 0 0 55 0 0 0] [ 1 0 1 0 0 4 2 3 2 3 0 0 3 5 7 4 6 0 3 1 0 0 3 4 1 0 0 34 2 2] [ 1 1 4 4 2 1 2 3 1 0 4 0 0 1 8 6 3 2 2 1 0 4 2 6 2 3 3 0 30 4] [ 1 1 3 2 4 1 3 4 2 1 3 1 4 2 5 1 4 3 4 1 0 4 2 4 2 1 3 4 10 16]]
<matplotlib.image.AxesImage at 0x1a7662d9550>
print("AlexNet Model Confusion Matrix")
print(cnfm_alexnet)
# show confusion matrix as a grey-level image
plt.imshow(cnfm_alexnet, cmap='gray')
AlexNet Model Confusion Matrix [[42 0 0 5 0 1 2 3 4 0 1 0 2 0 0 0 1 0 10 0 7 0 0 8 0 0 1 0 0 3] [ 0 28 0 0 1 4 1 0 0 0 3 7 0 1 5 3 1 0 2 0 0 2 2 1 0 5 0 0 2 3] [ 0 0 58 1 1 0 1 4 0 0 9 1 0 3 1 0 2 0 2 0 0 2 0 1 0 0 2 0 0 3] [ 2 0 0 71 0 0 1 1 2 0 0 0 1 0 0 0 0 1 0 0 2 1 0 6 0 0 1 0 0 3] [ 0 1 5 0 93 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 2] [ 1 5 1 0 0 37 2 3 0 4 1 3 0 0 5 4 0 3 2 1 1 2 0 1 1 2 0 6 2 2] [ 0 3 0 2 0 1 60 2 1 0 0 2 0 1 2 0 0 3 5 0 1 1 2 1 0 0 0 0 1 4] [ 2 0 1 0 0 3 1 51 2 0 0 2 4 1 2 1 1 1 3 0 2 0 1 7 0 0 2 0 3 1] [ 6 0 0 7 0 0 3 0 56 1 0 0 4 0 0 0 1 0 3 0 1 0 0 3 0 0 2 0 0 2] [ 0 0 1 1 0 0 1 0 0 69 0 1 2 0 1 2 2 0 0 0 0 0 1 0 3 1 1 0 1 0] [ 1 1 21 1 0 0 2 1 2 0 48 2 1 4 1 0 0 2 0 3 1 0 0 1 0 2 1 0 1 1] [ 0 3 4 1 0 0 1 1 0 0 0 65 1 0 0 0 0 6 0 0 1 3 0 0 0 0 4 0 0 2] [ 5 0 1 1 1 0 1 0 1 0 0 0 52 0 3 0 1 2 0 0 1 2 2 6 0 0 1 1 2 5] [ 2 1 6 1 1 0 2 1 0 0 10 1 3 44 3 1 0 1 2 1 1 3 0 0 0 0 4 1 1 4] [ 0 0 1 0 0 5 3 2 0 2 0 2 0 2 57 0 2 2 2 0 0 0 1 0 0 0 1 3 2 3] [ 2 1 7 0 0 5 4 8 0 4 1 1 0 5 2 31 0 2 0 1 3 1 0 1 2 4 0 1 5 3] [ 0 1 0 0 0 0 3 1 1 0 0 1 0 1 3 2 49 4 0 0 0 0 0 0 0 0 0 2 1 3] [ 1 0 0 0 0 1 0 2 0 0 0 3 0 0 1 1 0 67 0 0 1 1 0 1 0 0 0 1 0 3] [ 8 0 1 0 0 0 2 2 3 0 0 0 0 0 0 1 0 2 61 0 1 0 0 6 0 0 4 0 1 1] [ 0 1 15 1 1 0 1 0 1 0 8 4 1 4 0 2 2 1 0 49 1 3 0 0 0 0 3 0 0 0] [ 7 0 0 1 0 1 2 5 7 0 0 0 2 1 0 0 0 0 1 0 52 3 0 8 0 2 2 0 0 1] [ 1 1 5 1 1 0 1 3 0 1 0 5 0 0 2 1 1 4 2 0 2 36 0 6 1 1 5 0 2 0] [ 0 0 0 1 0 0 1 0 4 0 1 0 7 0 2 0 1 2 1 0 0 3 49 5 0 0 0 1 0 1] [ 8 0 0 10 0 0 0 2 6 0 1 0 4 1 0 0 0 1 7 0 4 2 0 36 0 4 4 0 1 2] [ 0 1 3 0 0 2 1 1 0 1 1 1 0 1 1 7 0 0 0 0 0 1 0 0 56 5 0 1 3 1] [ 1 2 4 0 1 2 1 3 0 1 1 2 1 1 0 6 1 0 0 2 0 1 0 0 4 41 1 1 6 3] [ 1 0 3 0 0 0 1 0 0 0 0 1 0 1 0 0 0 1 2 3 0 2 0 3 0 0 75 0 0 1] [ 0 0 0 0 0 4 4 1 0 1 0 0 1 0 7 0 3 1 1 1 1 0 0 2 0 1 0 60 3 0] [ 1 3 2 1 0 7 3 6 1 3 0 3 1 0 2 2 4 3 4 0 1 1 1 4 2 2 0 0 37 6] [ 1 1 1 3 0 1 3 6 1 2 1 2 7 1 5 2 1 5 3 2 1 7 1 2 1 1 0 3 3 29]]
<matplotlib.image.AxesImage at 0x1a767883610>
alexnet_second = alexnet_running_time%60
alexnet_minute = int((alexnet_running_time-alexnet_second)/60)
alexnet_hour = int(alexnet_minute/60)
alexnet_minute = alexnet_minute%60
cnn_second = cnn_running_time%60
cnn_minute = int((cnn_running_time-cnn_second)/60)
cnn_hour = int(cnn_minute/60)
cnn_minute = cnn_minute%60
print("The train running time of AlexNet model in TinyImageNet30")
print(f"{alexnet_hour} hour {alexnet_minute} minute {alexnet_second :3f} second")
print("The train running time of CNN model in TinyImageNet30")
print(f"{cnn_hour} hour {cnn_minute} minute {cnn_second :3f} second")
The train running time of AlexNet model in TinyImageNet30 1 hour 44 minute 27.141047 second The train running time of CNN model in TinyImageNet30 1 hour 44 minute 27.369974 second
Please use TinyImageNet30 dataset for all results
Use an existing library to initiate grad-CAM
- To install: !pip install torchcam
- Call SmoothGradCAMpp: from torchcam.methods import SmoothGradCAMpp
- Apply to your model
You can see the details here: https://github.com/frgfm/torch-cam
It is recommended to first read the relevant paper Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, and refer to relevant course material.
HINT for displaying images with grad-CAM:
Display heatmap as a coloured heatmap superimposed onto the original image. We recommend the following steps to get a clear meaningful display.
From torchcam.utils import overlay_mask. But remember to resize your image, normalise it and put a 1 for the batch dimension (e.g, [1, 3, 224, 224])
# Your code here!
class cam_MyDataset(Dataset):
def __init__(self, data_type, transform=train_transformer):
'''
data_type : ["train_set", "test_set"]
'''
root_path = "./comp5625M_data_assessment_1/"
self.data_type = data_type
data_root = pathlib.Path(root_path+self.data_type+"/"+self.data_type)
if self.data_type == "train_set":
all_image_paths = list(data_root.glob("*/*"))
self.all_image_paths = all_image_paths
self.all_image_labels = [int(classes[path.parent.name]) for path in all_image_paths]
self.all_image_paths = [str(path) for path in all_image_paths]
self.transform = transform
else:
all_image_paths = list(data_root.glob("*/"))
self.all_image_paths = [str(path) for path in all_image_paths]
self.all_image_labels = [str(path) for path in all_image_paths]
self.transform = transform
def __getitem__(self, index):
img = cv.imread(self.all_image_paths[index])
org_img=torch.tensor(img)
img=self.transform(img)
label = self.all_image_labels[index]
return img,label,org_img
def __len__(self):
return len(self.all_image_paths)
cam_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
cam_train_set = cam_MyDataset("train_set",transform=cam_transform)
cam_train_loader = DataLoader(
cam_train_set,
batch_size = 1,
shuffle =True)
def get_cam_pic(model,train_iter,ifcorr):
model.eval()
cam_extractor = SmoothGradCAMpp(model)
n = 0
pics=[]
for data in train_iter:
images,labels,org_img = data
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
activation_map = cam_extractor(outputs.squeeze(0).argmax().item(), out)[0]
pre=torch.max(outputs.data,1)[1].cpu().numpy()[0]
true=labels.cpu().numpy()[0]
if ifcorr:#correct classification
if true==pre:
pics.append((images,org_img, activation_map))
if len(pics)>=4:#4 photoes
break
else:#wrong classification
if true!=pre:
pics.append((images,org_img, activation_map))
if len(pics)>=4:#4 photoes
break
return pics
import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import normalize, resize
def plot_pic(p):
_,img,activation_map=p
img=img.cpu().numpy()[0,:,:,:]
img=resize(to_pil_image(img), (224, 224))
activation_map=activation_map.cpu().numpy()[0,:,:]
result = overlay_mask(img, to_pil_image(activation_map, mode='F'), alpha=0.5)
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(np.array(result))
plt.axis('off');
plt.tight_layout()
plt.subplot(1, 2, 2)
plt.imshow(img)
plt.axis('off');
plt.tight_layout()
plt.show()
cam_cnn=CNN_compared()
cam_cnn.load_state_dict(torch.load('cnn_compare.pt'))
cam_cnn = cam_cnn .to(device)
cnn_pics=get_cam_pic(cam_cnn,cam_train_loader,1)#find correct photo
print(len(cnn_pics))
for p in cnn_pics:
plot_pic(p)
WARNING:root:no value was provided for `target_layer`, thus set to 'maxpool3'.
4
cam_alexnet = models.alexnet(pretrained=False)
com_num_fc =cam_alexnet.classifier[6].in_features
cam_alexnet.classifier[6] = torch.nn.Linear(in_features=com_num_fc, out_features=30)
cam_alexnet.load_state_dict(torch.load('alexnet_compare.pt'))
cam_alexnet = cam_alexnet.to(device)
alexnet_pics=get_cam_pic(cam_alexnet,cam_train_loader,0)#find wrong photo
print(len(alexnet_pics))
for p in alexnet_pics:
plot_pic(p)
WARNING:root:no value was provided for `target_layer`, thus set to 'avgpool'.
4
a) Why model predictions were correct or incorrect? You can support your case from 6.2
The grad-CAM shows which part or parts of the image contribute and have a greater impact on the classification of the image. For correct classification, the model focus tends to be on the more obvious and more representative features. This is especially true for images with simple scenes, a single number and no occlusions. Then, incorrect classifications are often due to the model being inattentive and not better at identifying boundaries, focusing on too much unnecessary and unrepresentative information.
b) What can you do to improve your results further?
Get more data, deepen model layers, data augmentation, increase number of training epochs, transfer learning, adjust parameters, vary learning rate, weight decay, dropout, batch regularization.